[发明专利]一种基于元学习的领域增量方法在审
申请号: | 202011186818.9 | 申请日: | 2020-10-29 |
公开(公告)号: | CN112308211A | 公开(公告)日: | 2021-02-02 |
发明(设计)人: | 王杰龙;安竹林;程坦;徐勇军 | 申请(专利权)人: | 中国科学院计算技术研究所厦门数据智能研究院 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08;G06K9/62 |
代理公司: | 厦门致群财富专利代理事务所(普通合伙) 35224 | 代理人: | 刘兆庆 |
地址: | 361000 福建省*** | 国省代码: | 福建;35 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 学习 领域 增量 方法 | ||
本发明公开了一种基于元学习的领域增量方法,包括如下步骤:S1、构建预训练模型;S2、用预训练模型训练旧模型;S3、训练新模型。本发明采用随机保留的5%的记忆数据和新数据混合微调训练新模型,同时还采用交叉熵损失函数和知识蒸馏损失函数联合指导新模型的学习,使其在记住旧领域的分类知识的同时,学习新领域数据的分类知识,大大减少数据存储和训练时间的开销。
技术领域
本发明涉及计算机技术领域,特别涉及一种基于元学习的领域增量方法。
背景技术
随着深度学习的兴起,基于卷积神经网络的物体分类方法飞速发展,识别准确率得到很大提高。然而,基于卷积神经网络的方法也有缺陷:当测试的图片数据分布与训练图片数据分布不一致时,如光照、背景、姿态等发生变化,模型的准确率会下降。因此,当出现新的领域数据,即与原来的训练数据分布不一致的数据时,需要模型能够增量地学习新的领域分类,也就是在记住旧领域数据的分类同时学习新的领域知识分类。
目前,最直观的领域增量学习方法是用新领域的数据继续训练模型,但该方法往往存在精度不能满足要求的情况:如果训练不足,对于新领域数据准确率不高;如果训练过度,对于旧领域数据准确率则会下降,二者难以调和。而如果将旧领域数据和新领域数据直接混合重新训练卷积神经网络,则数据存储和训练时间都开销巨大,尤其是实际中随着新领域数据越来越多,开销将越来越大。因此,找到一种能够以低开销代价,获得高精度性能的领域增量识别方法尤为重要。
发明内容
为解决上述问题,本发明提供了一种基于元学习的领域增量方法。
本发明采用以下技术方案:
一种基于元学习的领域增量方法,包括如下步骤:
S1、构建预训练模型:利用元学习方法iTAML,选择若干公开的数据集作为元数据,构造元任务并学习一个预训练模型,获得所述预训练模型的参数φ,所述预训练模型为卷积神经分类网络;
S2、用预训练模型训练旧模型:构建一个与所述预训练模型相同类型的分类模型作为旧模型,将所述预训练模型的参数φ导入所述旧模型,并使用交叉熵损失函数指导旧数据Dold训练所述旧模型,训练结束后,随机采样保留5%的旧数据Dold作为记忆数据Dmemory;
S3、训练新模型:用所述记忆数据Dmemory和新数据Dnew一起混合训练所述旧模型,对于新数据Dnew使用交叉熵损失函数指导模型学习,对于记忆数据Dmemory使用交叉熵损失函数和知识蒸馏损失函数联合指导模型学习,从而得到新模型。
进一步地,所述卷积神经分类网络为VGG、ResNet、MobileNet、DenseNet或SENet中的一种。
进一步地,步骤S1中所述元学习方法iTAML的训练过程为增量式,共训练T个阶段,T为总任务数,t表示第t个任务;
当t=1,用交叉熵损失公式正常训练任务1的数据,得到预训练模型参数φ1,所述交叉熵损失公式如下:
其中,Dt表示属于第t个任务的数据集,共有N个样本,xi为其中一个,pi表示模型对xi的预测值,yi表示真实标签值;
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于中国科学院计算技术研究所厦门数据智能研究院,未经中国科学院计算技术研究所厦门数据智能研究院许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011186818.9/2.html,转载请声明来源钻瓜专利网。
- 上一篇:一种基于主动学习的样本标注方法
- 下一篇:一种整车制动系统的控制方法