分享
手动实现BERT
输入“/”快速插入内容
手动实现
BERT
飞书用户4443
5月24日修改
本文重点介绍了如何从零训练一个
BERT
模型的过程,包括整体上BERT模型架构、数据集如何做预处理、MASK替换策略、训练模型和保存、加载模型和测试等。
一.
BERT
架构
BERT
设计初衷是作为一个通用的backbone,然后在下游接入各种任务,包括翻译任务、分类任务、回归任务等。BERT模型架构如下所示:
1.
输入层
BERT
每次计算时输入两句话。
2.数据预处理
包括移除不能识别的字符、将所有字母小写、多余的空格等。
3.随机将一些词替换为MASK
BERT
模型的训练过程包括两个子任务,其中一个即为预测被遮掩的词的原本的词,所以在计算之前,需要把句子中的一些词替换为MASK交给BERT预测。
4.编码句子
把句子编码成向量,
BERT
同样也有位置编码层,以让处于不同位置的相同的词有不同的向量表示。与Transformer位置编码固定常量不同,BERT位置编码是一个可学习的参数。
5.
编码器
此处的编码器即为Transformer中的编码器,
BERT
使用了Transformer中的编码器来抽取文本特征。
6.预测两个句子的关系
BERT
的计算包括两个子任务,预测两个句子的关系为其中一个子任务,BERT要计算出输入的两个句子的关系,这一般是
二分类
任务。
7.预测MASK词
这是
BERT
的另外一个子任务,要预测出句子中的MASK原本的词。
二.数据集介绍和预处理
1.数据集介绍
数据集使用微软提供的MSR
Paraphrase
数据集进行训练,第1列的数字表示了这2个句子的意思是否相同,2列ID对于训练
BERT
模型没有用处,只需关注第1列和另外2列String。部分样例如下所示: