[发明专利]模型训练方法、文本分类方法、电子设备及存储介质在审
申请号: | 202010861649.8 | 申请日: | 2020-08-24 |
公开(公告)号: | CN111898707A | 公开(公告)日: | 2020-11-06 |
发明(设计)人: | 刘小康;李健铨;晋耀红 | 申请(专利权)人: | 鼎富智能科技有限公司 |
主分类号: | G06K9/62 | 分类号: | G06K9/62;G06F16/35;G06N3/04;G06N3/08 |
代理公司: | 北京超凡宏宇专利代理事务所(特殊普通合伙) 11463 | 代理人: | 蒋姗 |
地址: | 230000 安徽省合肥市*** | 国省代码: | 安徽;34 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 模型 训练 方法 文本 分类 电子设备 存储 介质 | ||
本申请提供一种模型训练方法、文本分类方法、电子设备及存储介质。方法包括根据学生模型第二transformer层中的各层到老师模型第一transformer层中各层之间的距离以及搬运量,对第二transformer层中每一层对应的权重进行更新获得更新后的权重;利用更新后的权重计算第二transformer层与第一transformer层之间的EMD,获得第一蒸馏损失;计算第二embedding层的第二蒸馏损失以及第二prediction层的第三蒸馏损失;根据第一蒸馏损失、第二蒸馏损失和第三蒸馏损失对学生模型中的参数进行训练,获得训练后的学生模型。本申请能够使学生模型学习到老师模型中更多的知识。
技术领域
本申请涉及自然语言处理技术领域,具体而言,涉及一种模型训练方法、文本分类方法、电子设备及存储介质。
背景技术
随着深度学习的发展,自然语言处理中深度神经网络的使用越来越多,为了能够提高模型的性能,大多数模型都比较复杂,参数量大,内存消耗大的问题,很难直接应用于GPU及智能手机等应用资源受限的设备上。
模型蒸馏方法较好的解决了上述问题,现有技术中,模型蒸馏方法大多是以下几种思路:
(1)将老师模型的输出(soft target)作为学生模型的学习目标,通过最小化学生模型与老师模型的差距对学生模型进行更新;也有将soft target和真实目标(hardtarget)以线性方式结合进行计算的。
(2)大多蒸馏方法是在最后的输出层上对知识进行学习,还有一些方法也在权重矩阵和隐藏的激活层上一起进行学习。
(3)学生模型的结构可以是LSTM等其他各种推理速度更快的架构,也可以是老师模型的缩小版,比如更小的网络层数,更小的隐含层维度。
(4)目标函数可以是交叉熵、KL散度、均方误差等。
上述的模型蒸馏方法中,是将学生模型中的每一层分别与老师模型中某一层相近似,例如:老师模型有12层,学生模型有4层,通过老师模型的第2层来学习学生模型的第1层,通过老师模型的第4层来生成学生模型的第2层,通过老师模型的第6层来生成学生模型的第3层,通过老师模型的第9层来生成学生模型的第4层。这种方式获得学生模型从老师模型中学习到的知识有限。
发明内容
本申请实施例的目的在于提供一种模型训练方法、文本分类方法、电子设备及存储介质,用以解决学生模型从老师模型中学习到的知识有限的问题。
第一方面,本申请实施例提供一种模型训练方法,获取老师模型和学生模型;其中,所述老师模型为训练好的模型,包括第一向量embedding层、第一转换器transformer层和第一预测prediction层,且所述第一transformer层包括M层;所述学生模型为待训练的模型,包括第二embedding层、第二transformer层和第二prediction层,且所述第二transformer层包括N层;M和N均为正整数,且M>N;获取训练样本,并根据训练样本获得第二transformer层中的第i层到第一transformer层中的第j层之间的距离;其中,i和j均为正整数,且i≤N,j≤M;根据第二transformer层中的第i层到第一transformer层中的第j层之间的距离以及搬运量,对所述第二transformer层中每一层对应的权重进行更新,获得对应层的更新后的权重;利用所述更新后的权重计算第二transformer层与第一transformer层之间的搬土距离EMD,获得第一蒸馏损失;分别计算第一embedding层的输出与第二embedding层的输出之间的第二蒸馏损失,以及第一prediction层的输出与第二prediction层的输出之间的第三蒸馏损失;根据所述第一蒸馏损失、所述第二蒸馏损失和所述第三蒸馏损失对所述学生模型中的参数进行训练,获得训练后的学生模型。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于鼎富智能科技有限公司,未经鼎富智能科技有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202010861649.8/2.html,转载请声明来源钻瓜专利网。