0%

NLP-Bert

BERT全名为(Bidirectional Encoder Representations from Transformers), 在NLP领域作为Word2vec的替代者,将NLP中的$11$项任务精度向前大大推进,作为一名算法工程师,非常有必要好好研究一下。

BERT的背景和特点

背景

​ 在CV领域,我们一般会使用基于Imagenet预训练的模型权重作为其他任务的basenet,这样不仅能够提高模型的泛化能力并且可以减少训练时间,使训练过程更加平稳,这是预训练带来的好处。在NLP领域,BERT就是要做这样的事情,在通用数据上学习词向量,并能够迁移到其他下游子任务。

特点

  1. Transformer结构的极致使用
  2. 在预训练时,使用MLM(Mask Language Model)NSP(Next Sentence Prediction)作为训练目标捕获词和句的关系,句和句的关系
    1. MLM:随机mask一些词进行预测,类似于完形填空,该任务主要捕获序列不同词之间的关系
    2. NSP:预测两个句子是否是连续前后句关系,给定两个句子,预测后面的句子是否是前面的句子的下一句,该任务主要捕获不同序列之间的关系
  3. 使用海量的数据进行训练,可以进行分布式训练
  4. 可以方便用于其他任务的迁移和微调,例如句子相关性,机器翻译,语音识别,关键词提取等NLP相关任务

BERT有可能存在的问题

  1. 由于transformer的利用,使用position embding来捕获序列中元素的位置关系,但是在巨量数据中,transformer对元素位置的信息表征不强

BERT如何工作

Bert使用了transformer结构,一种google提出的注意力学习机制,用以学习文本中不同token之间的上下文关系。在transformer中包含encoder和decoder部分,由于BERT的目的是生成语言模型,因此只需要transformer中的encoder部分。关于transformer,相关详细讲解以前有写过👉👉👉

transformer改变了原来使用RNN的方式,在RNN中为了捕获当前词语和其左右两侧词语的关系一般使用BiRNN设计,并且RNN不能进行并行设计,无法利用GPU加速运算。在Transformer中输入不再入RNN相同的是单个token而是整个sequence,并且加入输入的sequence中每个token的位置信息,这样每个token自然可以接收到它的上下文信息,并且这种设计可以使用GPU进行加速运算。

MLM

如开始所述,MLM任务类似于我们做过的完形填空,是为了学习token的上下文关系。在输入的序列中选择所有token的$15\%$的比例进行mask处理,在这$15\%$中选择$80\%$使用MASK进行替换,$15 \%$从此表中随机选择其它token替换,$10\%$保持原来的token不变,这样设计的MLM任务,既保证了当前的模型见过真正的词是什么,又见过随机替换的次和空白词,有助于模型能够学习到根据其余词推断当前词的能力。

在网络设计上,在模型输出的最后一层接一个全连接层将输出特征的维度转化为词表大小,然后在使用softmax和交叉熵损失来进行设计,和通用的分类网络最后部分相同,在计算损失的时候,只计算数据中被MASK的部分的损失。

NSP

NSP是为了学习句子之间上下文关系,通俗的将就是,给你两个句子,你来判断这两个句子是不是挨着的,如果是挨着的那么就是1,否则就是0。在BERT中模型接收的是成对的句子,并学习预测成对的第二个句子是否是第一个句子在原始文档中的后续句子。 在数据制作时,第二个句子有$50\%$的概率是真是的第一个句子的后续句子,$50\%$的概率是从其余的文档中随机抽选的句子,并且有一个先验就是,随机选择的句子和第一个句子是没啥关系的。为了让模型能够顺利学习两个句子之间的上下文关系,在数据原始数据中要进行一些处理,

  1. 在输入的序列的头部插入入CLS这个token,在第一个句子结束后插入SEP,如果两个句子满足上下文关系那么$CLS=1$,否则为$0$;
  2. 在输入的向量中加入token typeembeding,标识当前的token归属于第一个句子还是第二个句子。
  3. 加入位置向量,用来表征输入的token的位置信息。

在网络设计上,在模型输出的最后一层,取第一个位置进行全连接转化维度为$2$,进行一个0/1分类的损失计算。

BERT代码研究

下面结合官方代码进行分析 https://github.com/google-research/bert

图中Trm对应一个Transformer Block模块。在Tansformer中,position embeding将序列中任意两个点之间的距离都转化成了1,因此在BERT中,得益于transformer的应用,可以捕获所有层中的左右两侧的语义信息。

数据处理

BERT中,输入的编码向量由三部分组成

Tokens

将输入的单词划分成一组有限的公共子词单元,也就是“拆词”,这个在英文在试有必要的,这样能够能在单词的有效性和字符的灵活性之间取得一个折中的平衡;但是在中文中,最小的单位就是一个字,不需要再进行拆分。官方代码位置 https://github.com/google-research/bert/blob/master/tokenization.py#L161,我对其中进行了部分修改,

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import unicodedata
import six
import tensorflow as tf
import sys
import traceback
import logging

reload(sys)
sys.setdefaultencoding('utf8')

logger=logging

def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""

# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.

if not init_checkpoint:
return

m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return

model_name = m.group(1)

lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]

cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]

is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"

if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"

if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name,
opposite_flag))


def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""

# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab


def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output


def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens

class FullTokenizer(object):
"""Runs end-to-end tokenziation."""

def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

def tokenize(self, text):
split_tokens = []
logger.debug("text is: {:s}".format(text))
basic_tokenizers = self.basic_tokenizer.tokenize(text)
logger.debug("basic_tokenizers is {}".format(basic_tokenizers))
for token in basic_tokenizers:
logger.debug("BasicTokenizer output: {:s}".format(token))
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
logger.debug("Splits token is {}".format(split_tokens))
return split_tokens

def tokenize_unicoder(self, text):
"""
[summary]
just token the textual part, do not change the rest
Arguments:
text {[string]} -- ["id\tdet1\tdet2\ttext"]

Returns:
details [list] -- [id,det1,det2,token1,token2,token3,token4...]
"""
try:
if isinstance(text,str):
split_text = text.strip().split("\t")
else:
split_text = text
logger.info("split text {}".format(split_text))
image_id, image_url, title_text = split_text
if not AbaseClient.exists(Params.abase_image_text_relevance_data.format(image_id)):
logger.error("{} do not exist in abase".format(image_id))
return []
split_tokens = [image_id,image_url]
for token in self.basic_tokenizer.tokenize(title_text):
logger.debug("tokenize_unicoder: BasicTokenizer output is --> {}".format(token))
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
except:
logger.error(traceback.format_exc())
return []


def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)

def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.

Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case

def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)

# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)

orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))

output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens

def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)

def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1

return ["".join(x) for x in output]

def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)

def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True

return False

def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)


class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""

def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word

def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.

This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.

For example:
input = "unaffable"
output = ["un", "##aff", "##able"]

Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.

Returns:
A list of wordpiece tokens.
"""

text = convert_to_unicode(text)

output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue

is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end

if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens


def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False


def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False


def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False

if __name__ == "__main__":
# give the params: vocab file
tk = FullTokenizer(
"config/vocab.txt")
text = "10009090090 熊猫烧香"
tokens = tk.tokenize(text)
logger.info("tokens: {}".format(tokens))
ids = tk.convert_tokens_to_ids(tokens)
logger.info("ids: {}".format(ids))

执行上述命令,可以得到

1
2
tokens: [u'1000', u'##90', u'##900', u'##90', u'\u718a', u'\u732b', u'\u70e7', u'\u9999']
ids: [8212, 8599, 10589, 8599, 4220, 4344, 4173, 7676]

可以看到对于中文进行wordpiece分词之后得到的是单个字,对于字符串分词之后得到的是多个字符串,在结果中u'##90', u'##900', u'##90' 可以看成这几个分割之后的属于一个此种的几个字,例如原词是“中国”,分词之后是‘##中’,‘##国’

PositionEmbedding

位置嵌入是指将单词的位置信息编码成特征向量,位置嵌入是向模型中引入单词位置关系的至关重要的一环。这里和transformer不同,在transformer中使用的是三角函数,在Bertposition是待学习参数,在Bert的参数定义文件中一开始设置最长的输入序列长度,这样就确定了输入的序列的最的位置的最大值和最小值,如此将token所在的位置按照词向量的方式生成PositionEmbedding

TokenType

词的类别,这个很好理解,在Bert中输入的是大于等于两个句子,那么我们就规定分次之后的token归属于第一个句子那么其tokentype是$0$,如果归属于第二个句子那么tokentype是$1$。 下面是对输入sequence转化为WordEmbedxidgs之后进行后处理的代码,链接为 https://github.com/google-research/bert/blob/master/modeling.py#L428

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def embedding_postprocessor(input_tensor,
use_token_type=False,
token_type_ids=None,
token_type_vocab_size=16,
token_type_embedding_name="token_type_embeddings",
use_position_embeddings=True,
position_embedding_name="position_embeddings",
initializer_range=0.02,
max_position_embeddings=512,
dropout_prob=0.1):
"""Performs various post-processing on a word embedding tensor.
Args:
input_tensor: float Tensor of shape [batch_size, seq_length,
embedding_size].
use_token_type: bool. Whether to add embeddings for `token_type_ids`.
token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
Must be specified if `use_token_type` is True.
token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
token_type_embedding_name: string. The name of the embedding table variable
for token type ids.
use_position_embeddings: bool. Whether to add position embeddings for the
position of each token in the sequence.
position_embedding_name: string. The name of the embedding table variable
for positional embeddings.
initializer_range: float. Range of the weight initialization.
max_position_embeddings: int. Maximum sequence length that might ever be
used with this model. This can be longer than the sequence length of
input_tensor, but cannot be shorter.
dropout_prob: float. Dropout probability applied to the final output tensor.
Returns:
float tensor with same shape as `input_tensor`.
Raises:
ValueError: One of the tensor shapes or input values is invalid.
"""
# shape is [batch,sequence_length,embeding_dims]
input_shape = get_shape_list(input_tensor, expected_rank=3)
batch_size = input_shape[0]
seq_length = input_shape[1]
width = input_shape[2]

output = input_tensor

if use_token_type:
if token_type_ids is None:
raise ValueError("`token_type_ids` must be specified if"
"`use_token_type` is True.")
token_type_table = tf.get_variable(
name=token_type_embedding_name,
shape=[token_type_vocab_size, width],
initializer=create_initializer(initializer_range))
# This vocab will be small so we always do one-hot here, since it is always
# faster for a small vocabulary.
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
token_type_embeddings = tf.reshape(token_type_embeddings,
[batch_size, seq_length, width])
output += token_type_embeddings

if use_position_embeddings:
assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
full_position_embeddings = tf.get_variable(
name=position_embedding_name,
shape=[max_position_embeddings, width],
initializer=create_initializer(initializer_range))
# Since the position embedding table is a learned variable, we create it
# using a (long) sequence length `max_position_embeddings`. The actual
# sequence length might be shorter than this, for faster training of
# tasks that do not have long sequences.
#
# So `full_position_embeddings` is effectively an embedding table
# for position [0, 1, 2, ..., max_position_embeddings-1], and the current
# sequence has positions [0, 1, 2, ... seq_length-1], so we can just
# perform a slice.
position_embeddings = tf.slice(full_position_embeddings, [0, 0],
[seq_length, -1])
num_dims = len(output.shape.as_list())

# Only the last two dimensions are relevant (`seq_length` and `width`), so
# we broadcast among the first dimensions, which is typically just
# the batch size.
position_broadcast_shape = []
for _ in range(num_dims - 2):
position_broadcast_shape.append(1)
position_broadcast_shape.extend([seq_length, width])
position_embeddings = tf.reshape(position_embeddings,
position_broadcast_shape)
output += position_embeddings

output = layer_norm_and_dropout(output, dropout_prob)
return output

通过代码很容易看出来,在BERT中对于tokentypeposition都是学习参数化。

就像上面代码所述

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
"""
Input
"""
output = input_tensor
"""
process token type
"""
output += token_type_embeddings
"""
process position
"""
output += position_embeddings
"""
LN and dropout
"""
output = layer_norm_and_dropout(output, dropout_prob)

将各个部分的embedding叠加在一起然后经过LNdropout,就得到了进入transformer之前输入特征。

数据特点

BERT是一个基础特征提取模型,主要学习的是词向量,具体到我们的下游任务,需要在BERT的基础上继续训练才可以。

自然语料训练

目的:学习模型的通用表达能力

语料:最好是document-level的,比如百科答,wiki,新闻等通用预料,如此能够保证预训练阶段得到的网络权重足够鲁棒以适应后续的特定领域的下游任务。

下游任务训练

目的:学习特定的模型表达能力,比如相关性,问答等子任务

语料:特定任务的标定数据,比如要学习queryrecall的相关性,那么就要准备相应的数据进行训练。

基于BERT的应用

既然BERT这么好用,那么是否能在其他领域也借鉴使用呢,不止我这么想,很多大牛已经尝试过了,由此催生出了videoBert,ELECTRA,uniter,uncoder等这些基于BERT的相关研究。

图文BERT

本人目前在做图片搜索相关工作,复现了uniter这篇论文的代码,由于工作原因,目前代码无法开源,实际上这部分还是很简单的,基于google-bert的代码,加入图片visual部分的处理前处理和后续loss的计算。



本篇论文是将检测和图片描述作为提个一组输入,数据的构成方式和BERT类似,一些区别点是

  1. 第一个句子用图片检测RPN之后得到的ROIPooling特征,第二个句子对应着当前这张图片的描述信息
  2. 使用和BERT相似的训练策略,不过略有不同,在原来MLMNSP基础上加对视觉部分的MRM
  3. MLMMRM是基于条件性的,二者不能同时进行,这个在制作数据的时候要注意
  4. 在数据预处理过程中,视觉部分和文本部分分别处理之后进行concat,再输入给transformer部分
  5. Loss部分在原来的基础上计入了MRM部分的损失,并且论文中针对MRM部分的损失设计了三种方式
    1. 视觉特征回归MRFR
    2. 类别分类MRC
    3. KL散度类别回归MRC-kl

Reference

  1. Official Github: https://github.com/google-research/bert
  2. 知乎: https://zhuanlan.zhihu.com/p/48612853
赏杯咖啡!