[发明专利]基于知识蒸馏的信息检索方法有效
申请号: | 202110534072.4 | 申请日: | 2021-05-17 |
公开(公告)号: | CN113312548B | 公开(公告)日: | 2022-05-03 |
发明(设计)人: | 鲁伟明;朱堂灿;庄越挺 | 申请(专利权)人: | 浙江大学 |
主分类号: | G06F16/9535 | 分类号: | G06F16/9535;G06N5/02;G06N20/10 |
代理公司: | 杭州求是专利事务所有限公司 33200 | 代理人: | 刘静 |
地址: | 310058 浙江*** | 国省代码: | 浙江;33 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 基于 知识 蒸馏 信息 检索 方法 | ||
1.一种基于知识蒸馏的信息检索方法,其特征在于,包括以下步骤:
1)训练教师模型:基于交叉熵损失函数,利用训练集T来训练教师模型;具体步骤为,
训练集T为其中Qi表示查询,pi和ni为正负例,N为总的查询数量;首先,选择教师模型为BERT-CAT模型,则教师模型计算查询Q与段落d之间相关性的评分公式为:
Teacher(Q,d)=BERT-CAT(Q,d)=BERT([CLS;Q;SEP;d])1*W
其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS和SEP表示BERT中的特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵;
之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用该教师模型计算正例得分Pi以及负例得分Ni:
Pi=Teacher(Qi,pi)
Ni=Teacher(Qi,ni)
再通过正负例得分计算相应的交叉熵损失:
最后通过最小化交叉熵损失来优化教师模型,训练得到最终的教师模型;
2)训练集段落重排序:使用步骤1)训练后的教师模型,对训练集Told中每个查询所对应的段落集进行相关性重排序,得到排序πT,并用重排序后的段落集构建新训练集Tnew;具体步骤为,
利用教师模型对训练集Told进行重排序;
基于步骤1)所训练的教师模型Teacher,对于训练集Told中每个查询Q所对应的一个段落集D={d1,d2,...,dl},使用模型Teacher对所有段落进行相对于查询Q的打分:
S=Teacher(Q,D)={s1,s2,...,sl}
其中,si=Teacher(Q,di),之后根据每个段落得分的高低对所有段落进行重排序,得到一个新的有序的段落集Dr={dr1,dr2,...,drl},其中sr1>sr2>…>srl,所有查询对应的有序段落集构成新训练集Tnew;
3)训练学生模型:利用训练集T,计算学生模型的交叉熵损失L1;然后,利用学生模型,对训练集Tnew中每个查询所对应的段落集进行相关性重排序,得到排序πS,再利用列表置换损失函数计算πT与πS之间的差异损失L2;最后用L1和L2的加权和作为学生模型的最终损失L,并通过最小化L来训练学生模型;具体步骤为;
首先,选择BERT-DOT模型和ColBERT模型作为学生模型Student;
BERT-DOT模型计算查询Q与段落d之间相关性的评分公式为:
rq=BERT([CLS;Q])1*W
rd=BERT([CLS;d])1*W
BERT-DOT(Q,d)=rq·rd
其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算;
ColBERT模型计算查询Q与段落d之间相关性的评分公式为:
rq=BERT([CLS;Q;rep(MASK)])1*W
rd=BERT([CLS;d])1*W
其中,BERT是一种基于Transformer的双向编码表示语言模型,CLS表示特殊词条,“;”表示拼接操作,rep(MASK)表示多个MASK词条拼接而成的词条集,下标1表示取CLS词条,W表示一个权重矩阵,·表示内积运算;
之后,对训练集T中每个查询及其所对应正例和负例的三元组,使用学生模型计算正例得分Pi以及负例得分Ni:
Pi=Student(Q,pi)
Ni=Student(Q,ni)
其中,Student代表BERT-DOT模型和ColBERT模型;
之后通过正负例得分计算相应的交叉熵损失:
接着计算重排序序列的列表置换损失函数;
根据步骤2)所得的重排序段落训练集Tnew,对于每个查询Q,有对应的重排序后的段落集Dr={dr1,dr2,...,drl},段落di相对于查询Q使用教师模型所得到的分数si,满足sr1>sr2>…>srl;
使用学生模型重新计算所有段落相对于查询Q的得分,得到一个新的分数列表:
S′=Student(Q,Dr)={s′r1,s′r2,...,s′rl}
根据该列表,得到查询置换的概率:
之后最大化每个查询置换概率的对数似然,即最小化列表置换损失函数:
最后,将两部分损失加权求和作为模型的损失:
Loss=Loss1+αLoss2
其中,α为权重参数;
4)利用学生模型进行信息检索:利用学生模型计算用户查询所对应的段落的评分,将评分最高的段落作为查询答案。
2.根据权利要求1所述的一种基于知识蒸馏的信息检索方法,其特征在于,所述步骤4)具体为:
利用学生模型进行信息检索;
在步骤3)训练得到学生模型后,使用该学生模型对测试集中相应查询所对应的段落集进行重排序,获取排行最高的段落作为查询答案,以此来测试模型的效果;
对于用户给定的问题,在语料库中初步筛选出相应段落,再用学生模型计算段落相对于问题的得分,根据得分的高低将相应用户所需要的答案量的答案提供给用户。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于浙江大学,未经浙江大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110534072.4/1.html,转载请声明来源钻瓜专利网。
- 信息记录介质、信息记录方法、信息记录设备、信息再现方法和信息再现设备
- 信息记录装置、信息记录方法、信息记录介质、信息复制装置和信息复制方法
- 信息记录装置、信息再现装置、信息记录方法、信息再现方法、信息记录程序、信息再现程序、以及信息记录介质
- 信息记录装置、信息再现装置、信息记录方法、信息再现方法、信息记录程序、信息再现程序、以及信息记录介质
- 信息记录设备、信息重放设备、信息记录方法、信息重放方法、以及信息记录介质
- 信息存储介质、信息记录方法、信息重放方法、信息记录设备、以及信息重放设备
- 信息存储介质、信息记录方法、信息回放方法、信息记录设备和信息回放设备
- 信息记录介质、信息记录方法、信息记录装置、信息再现方法和信息再现装置
- 信息终端,信息终端的信息呈现方法和信息呈现程序
- 信息创建、信息发送方法及信息创建、信息发送装置