[发明专利]一种基于多模型融合的特征蒸馏方法、系统、设备和介质在审
申请号: | 202210142194.3 | 申请日: | 2022-02-16 |
公开(公告)号: | CN114462546A | 公开(公告)日: | 2022-05-10 |
发明(设计)人: | 王曦;蹇易 | 申请(专利权)人: | 上海云从企业发展有限公司 |
主分类号: | G06K9/62 | 分类号: | G06K9/62;G06V10/774;G06V10/80;G06V10/74;G06V10/771 |
代理公司: | 上海光华专利事务所(普通合伙) 31219 | 代理人: | 张双凤 |
地址: | 201203 上海市浦东新区中国(上海)自*** | 国省代码: | 上海;31 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 模型 融合 特征 蒸馏 方法 系统 设备 介质 | ||
本发明提出一种基于多模型融合的特征蒸馏方法、系统、设备和介质,包括:通过预训练的多个教师模型分别获取目标数据的特征作为第一特征;通过学生模型的主干网络获取所述目标数据的第二特征,将所述第二特征分别输入多个第一蒸馏子网络,通过每个所述第一蒸馏子网络分别输出与所述第一特征相似度达到设定阈值的第二特征;将所有所述第一特征进行融合得到第一融合特征,并将各所述蒸馏子网络输出的第二特征进行融合得到第二融合特征,将所述第一融合特征和第二融合特征输入第二蒸馏子网络,获取所述目标数据的蒸馏特征;本发明充分利用不同教师模型的优势,从局部和全局两个方向进行蒸馏学习,提升学生模型的识别性能。
技术领域
本发明涉及人工智能领域,尤其涉及一种基于多模型融合的特征蒸馏方法、系统、设备和介质。
背景技术
模型压缩以及知识提取是模型部署中关键的步骤,其中以模型蒸馏为主的训练方法被大家广泛使用,主流的模型蒸馏方法会预先训练一个大模型(教师模型),在分类层计算各个类别的概率,以这个概率分布作为“暗知识”,利用KL散度的距离度量指导小模型(学生模型)学习到大模型的知识.
在人脸识别任务中,此方法面临如下几个问题:人脸识别任务类别数巨大,会造成教师模型中的暗知识矩阵分布过于庞大,不利于学习,甚至十分消耗显存等硬件资源;多个教师模型的特征融合会组成性能更加强大的教师模型,但是不当的训练方式无法充分获得多个教师带来的收益,反而提升了特征的长度,带来计算与存储的负担。
发明内容
鉴于以上现有技术存在的问题,本发明提出一种基于多模型融合的特征蒸馏方法、系统、设备和介质,主要解决现有教师模型暗知识矩阵过于庞大,对硬件要求高,且特征之间存在冗余,不利于学生模型的学习和计算的问题。
为了实现上述目的及其他目的,本发明采用的技术方案如下。
一种基于多模型融合的特征蒸馏方法,包括:
通过预训练的多个教师模型分别获取目标数据的特征作为第一特征;
通过学生模型的主干网络获取所述目标数据的第二特征,将所述第二特征分别输入多个第一蒸馏子网络,通过每个所述第一蒸馏子网络分别输出与所述第一特征相似度达到设定阈值的第二特征;
将所有所述第一特征进行融合得到第一融合特征,并将各所述蒸馏子网络输出的第二特征进行融合得到第二融合特征,将所述第一融合特征和第二融合特征输入第二蒸馏子网络,获取所述目标数据的蒸馏特征。
可选地,所述第一蒸馏子网络包括:注意力模块、归一化层、相似计算层以及至少一个全连接层,
注意力模块根据所述全连接层输出特征的特征值大小获取对应特征的权重输出至所述归一化层;
所述归一化层根据所述全连接层输出特征以及所述注意力模块输出权重完成对应特征归一化;
相似计算层通过预设的损失函数获取归一化后的特征与对应教师模型输出的第一特征之间的相似度。
可选地,所述注意力模块通过映射函数将特征值映射到-1至1之间。
可选地,所述映射函数包括:softmax函数、sigmoid函数。
可选地,所述第二蒸馏子网络与所述蒸馏子网络采用相同的网络结构。
可选地,将所述第一融合特征和第二融合特征输入第二蒸馏子网络之前,还包括:
对所述第一融合特征采用降维算法进行降维处理。
可选地,所述第一蒸馏子网络的数量与所述教师模型的数量相对应,且每个第一蒸馏子网络分别接收一个所述教师模型的第一特征。
一种基于多模型融合的特征蒸馏系统,包括:
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于上海云从企业发展有限公司,未经上海云从企业发展有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202210142194.3/2.html,转载请声明来源钻瓜专利网。
- 上一篇:一种马铃薯芽眼识别方法及设备
- 下一篇:一种代码生成方法、装置及终端设备