[发明专利]用于知识蒸馏的网络训练方法、装置、介质与电子设备有效
申请号: | 201910923038.9 | 申请日: | 2019-09-27 |
公开(公告)号: | CN110674880B | 公开(公告)日: | 2022-11-11 |
发明(设计)人: | 田野 | 申请(专利权)人: | 北京迈格威科技有限公司 |
主分类号: | G06K9/62 | 分类号: | G06K9/62;G06N3/04 |
代理公司: | 北京律智知识产权代理有限公司 11438 | 代理人: | 王辉;阚梓瑄 |
地址: | 100190 北京市海淀区科*** | 国省代码: | 北京;11 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 用于 知识 蒸馏 网络 训练 方法 装置 介质 电子设备 | ||
1.一种用于知识蒸馏的网络训练方法,其特征在于,包括:
将样本数据输入教师网络,获得所述样本数据对应的软标签数据,将所述样本数据输入学生网络,获得所述样本数据对应的预测数据;
基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数;
根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数;
其中,所述教师网络和所述学生网络用于图像分类,所述样本数据包括样本图片,所述硬标签数据包括所述样本图片的分类标签,所述软标签数据包括通过所述教师网络识别所述样本图片中存在目标对象的概率数据,所述预测数据包括通过所述学生网络识别所述样本图片中存在目标对象的概率数据;
所述损失函数为:
其中,L为所述损失函数,i表示所述硬标签数据的类别,yi为第i类硬标签数据,为第i类硬标签数据对应的预测数据,为第i类硬标签数据对应的软标签数据;∈为经验参数,min(yi)∈max(yi);a、b、c均为非负的权重参数,b不为0,且a和c中至少一个不为0。
2.根据权利要求1所述的方法,其特征在于,所述基于所述预测数据、所述软标签数据和所述样本数据对应的硬标签数据,构建损失函数,包括:
根据所述预测数据和所述硬标签数据,构建第一子损失;
根据所述预测数据和所述软标签数据,构建第二子损失;
根据所述第一子损失和所述第二子损失,确定所述损失函数。
3.根据权利要求2所述的方法,其特征在于,所述样本数据包括正样本;所述根据所述预测数据和所述软标签数据,构建第二子损失,包括:
根据所述正样本对应的预测数据和所述正样本对应的软标签数据,构建所述第二子损失。
4.根据权利要求3所述的方法,其特征在于,所述根据所述损失函数更新所述教师网络中的参数和所述学生网络中的参数,包括:
根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数;
根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数。
5.根据权利要求4所述的方法,其特征在于,所述正样本对应的预测数据包括对所述正样本的学生预测值和所述学生预测值对应的概率;所述根据所述损失函数和所述正样本对应的预测数据,更新所述学生网络中的参数,包括:
根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1。
6.根据权利要求5所述的方法,其特征在于,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述根据所述损失函数对所述学生预测值的梯度,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1,包括:
根据所述损失函数对所述学生预测值的梯度,以及所述学生预测值和所述教师预测值之间的误差,更新所述学生网络中的参数,使所述学生预测值对应的概率趋近于1和所述教师预测值对应的概率。
7.根据权利要求5所述的方法,其特征在于,所述正样本对应的软标签数据包括对所述正样本的教师预测值和所述教师预测值对应的概率,所述根据所述损失函数和所述正样本对应的软标签数据,更新所述教师网络中的参数,包括:
根据所述损失函数对所述教师预测值的梯度,更新所述教师网络中的参数,使所述教师预测值对应的概率趋近于1。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于北京迈格威科技有限公司,未经北京迈格威科技有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/201910923038.9/1.html,转载请声明来源钻瓜专利网。