[发明专利]一种基于对抗训练的端到端任务型对话学习框架和方法有效
申请号: | 202011299935.6 | 申请日: | 2020-11-19 |
公开(公告)号: | CN112541060B | 公开(公告)日: | 2021-08-13 |
发明(设计)人: | 何万伟;杨敏;李成明;姜青山 | 申请(专利权)人: | 中国科学院深圳先进技术研究院 |
主分类号: | G06F16/332 | 分类号: | G06F16/332;G06N3/04;G06N3/08;G06F40/216 |
代理公司: | 北京市诚辉律师事务所 11430 | 代理人: | 耿慧敏 |
地址: | 518055 广东省深圳*** | 国省代码: | 广东;44 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 对抗 训练 端到端 任务 对话 学习 框架 方法 | ||
1.一种基于对抗训练的端到端任务型对话学习方法,包括:
预训练学生网络、第一教师网络和第二教师网络,其中,第一教师网络用于从知识库中检索实体词;第二教师网络用于从对话数据中学习语言模式,学生网络用于学习提取实体词并生成回复语句;
训练生成对抗网络以从第一教师网络和第二教师网络向学生网络迁移知识,该生成对抗网络包含第一判别器、第二判别器和生成器,其中学生网络作为生成器,经对抗式训练产生对话回复语句,第一判别器用于区分学生网络和第一教师网络产生的输出分布,第二判别器用于区分学生网络和第二教师网络产生的输出分布;
将经对抗式训练的学生网络用于任务型对话。
2.根据权利要求1所述的方法,其中,在预训练过程中,采用自我批判序列算法优化第一教师网络和第二教师网网络,第二教师网络的损失函数设置为:
第一教师网络的损失函数设置为:
其中是第一教师网络的输出分布,ys是随机采样的输出,是第二教师网络的输出分布,是每个解码时间步从相应输出分布中获得概率最大的词,BLEU表示布勒分值,F1表示实体F1分值。
3.根据权利要求1所述的方法,其中,在预训练过程中,学生网络的损失函数设置为输出分布和真实的目标词汇yt之间的交叉熵,表示为:
其中,T是输出回复语句的长度。
4.根据权利要求1所述的方法,其中,第一判别器和第二判别器是采用门控循环单元的二分类器,对于每个判别器,采用门控循环单元来编码输出分布Po,门控循环单元的最后一个隐藏层状态hT被传递给输出层,在每个时间步t考虑当前的输出分布时,判别器定义为:
D(Po)=sigmoid(WhT)
其中W是可学习的参数,ht是在第t个时间步的隐藏层状态,BiGRU表示双向门控循环单元。
5.根据权利要求3所述的方法,其中,在训练生成对抗网络过程中,当训练第一判别器DKB和第二判别器DDP时,固定学生网络的参数不变,第一判别器和第二判别器被训练用于最小化分配错误标签给学生网络和相应教师网络的输出分布的概率,表示为LDP和LKB:
LDP=-log(1-DDP(Ps))-log(DDP(Pdp))
LKB=-log(1-DKB(Ps))-log(DKB(Pkb))
LD=LDP+LKB
其中,Ps是学生网络产生的输出分布,Pkb是第一教师网络产生的输出分布,Pdp是第二教师网络产生的输出分布,LD是用于更新判别器的目标函数。
6.根据权利要求5所述的方法,其中,在训练生成对抗网络过程中,当更新学生网络时,以最小化对抗损失函数LG为目标,表示为:
LG=-log(DDP(Ps))-log(DKB(Ps))
学生网络最终的损失函数表示为:
其中α是一个标量,用于决定在学生网络中对抗损失LG的重要程度。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于中国科学院深圳先进技术研究院,未经中国科学院深圳先进技术研究院许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011299935.6/1.html,转载请声明来源钻瓜专利网。