[发明专利]一种网络模型蒸馏方法及装置在审
申请号: | 202010055355.6 | 申请日: | 2020-01-17 |
公开(公告)号: | CN111260056A | 公开(公告)日: | 2020-06-09 |
发明(设计)人: | 岳凯宇;邓江帆;周峰 | 申请(专利权)人: | 北京爱笔科技有限公司 |
主分类号: | G06N3/08 | 分类号: | G06N3/08;G06N3/04 |
代理公司: | 北京集佳知识产权代理有限公司 11227 | 代理人: | 柳欣 |
地址: | 100094 北京市海淀区北清路*** | 国省代码: | 北京;11 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 网络 模型 蒸馏 方法 装置 | ||
本申请实施例公开了一种网络模型蒸馏方法及装置,具体地,从第一网络模型(老师模型)的蒸馏位点获取第一通道特征集合,包括M个第一通道特征。同时从第二网络模型(学生模型)的蒸馏位点获取第二通道特征集合,包括N个第二通道特征。按照预设规则及匹配算法从第一通道特征集合中确定出与第二通道特征集合匹配的第三通道特征集合,包括N个通道特征,使得第三通道特征集合与第二通道特征集合匹配。最后,根据第二通道特征集合与第三通道特征集合所匹配的一对通道特征,构建该对通道特征的距离损失函数,利用该距离损失函数对第二网络模型的参数更新,直至构建的距离损失函数满足预设距离阈值,使得第二网络模型学习到第一网络模型的特征表达。
技术领域
本申请涉及自动机器学习技术领域,具体涉及一种网络模型蒸馏方法及装置。
背景技术
卷积神经网络模型蒸馏是一种在广泛使用的小模型训练方法,通常情况况下,小模型具有参数量少、运行速度快、计算资源消耗少的优点,但由于小模型的参数规模较小而存在性能瓶颈、识别准确率不高。模型蒸馏则是使用参数规模较大、性能优异的大模型去引导小模型的训练过程,使后者间接习得前者的特征表达方式,从而达到提升自身性能的目的。
其中,模型蒸馏中最主要的步骤在于训练过程中,在大模型和小模型的特定层级(蒸馏位点)的输出特征之间构建距离损失函数,通过该距离损失函数促使小模型的参数进行迭代更新,进而使得小模型输出的特征表达逼近大模型,以使得小模型的识别准确率提高。
然而,由于大模型和小模型的参数规模不同,导致从大模型选定的特征对应的通道数目与从小模型选定的特征对应的通道数目不对应,因此,在构造距离损失函数时,需要通过额外的转换算子来对大模型的通道数目进行缩减,但这种缩减会引入额外的参数,增加计算开销。
发明内容
有鉴于此,本申请实施例提供一种网络模型蒸馏方法及装置,以实现更为合理有效地使得两个模型之间的通道数据对应,并减小计算开销。
为解决上述问题,本申请实施例提供的技术方案如下:
在本申请实施例第一方面,提供了一种网络模型蒸馏方法,所述方法包括:
从第一网络模型的蒸馏位点获取第一通道特征集合,所述第一网络模型为利用训练样本预先训练生成的老师模型,所述第一通道特征集合包括M个第一通道特征,其中,M为大于1的正整数;
从第二网络模型的蒸馏位点获取第二通道特征集合,所述第二网络模型为学生模型,所述第二通道特征集合包括N个第二通道特征,其中,N为大于1的正整数,M大于N;
根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,所述第三通道特征集合包括N个通道特征;
针对所述第二通道特征集合和所述第三通道特征集合所匹配的一对通道特征,构建该对通道特征对应的距离损失函数,以根据所述距离损失函数对所述第二网络模型的参数进行更新,直至所述构建的距离损失函数满足预设距离阈值。
在一些可能的实现方式中,所述根据预设规则以及匹配算法从所述第一通道特征集合中确定出与所述第二通道特征集合匹配的第三通道特征集合,包括:
当所述预设规则为稀疏匹配时,计算所述第二通道特征集合中每一个所述第二通道特征与所述第一通道特征集合中每个所述第一通道特征之间的距离,构成第一距离矩阵,所述第一距离矩阵大小为N*M;
对所述第一距离矩阵进行补充操作,添加P个距离数值,以使得补充后的第一距离矩阵大小为M*M,所述P等于M*M减去N*M;
针对所述补充后的距离矩阵中的任一行,选择最小距离数值;
将所述最小距离数值对应的第一通道特征确定为目标通道特征;
将各个所述目标通道特征构成第三通道特征集合。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于北京爱笔科技有限公司,未经北京爱笔科技有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202010055355.6/2.html,转载请声明来源钻瓜专利网。