[发明专利]一种基于元学习的领域增量方法在审
申请号: | 202011186818.9 | 申请日: | 2020-10-29 |
公开(公告)号: | CN112308211A | 公开(公告)日: | 2021-02-02 |
发明(设计)人: | 王杰龙;安竹林;程坦;徐勇军 | 申请(专利权)人: | 中国科学院计算技术研究所厦门数据智能研究院 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08;G06K9/62 |
代理公司: | 厦门致群财富专利代理事务所(普通合伙) 35224 | 代理人: | 刘兆庆 |
地址: | 361000 福建省*** | 国省代码: | 福建;35 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 学习 领域 增量 方法 | ||
1.一种基于元学习的领域增量方法,其特征在于:包括如下步骤:
S1、构建预训练模型:利用元学习方法iTAML,选择若干公开的数据集作为元数据,构造元任务并学习一个预训练模型,获得所述预训练模型的参数φ,所述预训练模型为卷积神经分类网络;
S2、用预训练模型训练旧模型:构建一个与所述预训练模型相同类型的分类模型作为旧模型,将所述预训练模型的参数φ导入所述旧模型,并使用交叉熵损失函数指导旧数据Dold训练所述旧模型,训练结束后,随机采样保留5%的旧数据Dold作为记忆数据Dmemory;
S3、训练新模型:用所述记忆数据Dmemory和新数据Dnew一起混合训练所述旧模型,对于新数据Dnew使用交叉熵损失函数指导模型学习,对于记忆数据Dmemory使用交叉熵损失函数和知识蒸馏损失函数联合指导模型学习,从而得到新模型。
2.如权利要求1所述的一种基于元学习的领域增量方法,其特征在于:所述卷积神经分类网络为VGG、ResNet、MobileNet、DenseNet或SENet中的一种。
3.如权利要求1所述的一种基于元学习的领域增量方法,其特征在于:步骤S1中所述元学习方法iTAML的训练过程为增量式,共训练T个阶段,T为总任务数,t表示第t个任务;
当t=1,用交叉熵损失公式正常训练任务1的数据,得到预训练模型参数φ1,所述交叉熵损失公式如下:
其中,Dt表示属于第t个任务的数据集,共有N个样本,xi为其中一个,pi表示模型对xi的预测值,yi表示真实标签值;
当t≥2,则初始化参数为上一阶段训练好的参数φbase=φt-1,分别取出任务1、任务2、...、任务t,共t个任务数据以φbase为初始参数用交叉熵损失更新优化,得到对应任务的临时参数φ1,φ2,…φt,然后更新φbase,当损失不下降时获得该阶段的最终结果参数φt=φbase,所述更新φbase时采用如下公式:
最终,以获得的φT作为预训练模型的参数。
4.如权利要求1所述的一种基于元学习的领域增量方法,其特征在于:步骤S3中所述使用交叉熵损失函数和知识蒸馏损失函数联合指导模型学习时采用的是整体损失,其公式为:Loss=loss_ce+loss_distill,其中loss_ce表示交叉熵损失,loss_distill表示知识蒸馏损失,
而loss_ce的求解公式如下:
其中,xi∈Dmemory∪Dnew表示属于记忆数据或新数据的样本,共有N个,pi表示模型对xi的预测值,yi表示真实标签值;
而loss_distill的求解公式如下:
其中,xi∈Dmemory表示属于记忆数据的样本,共有N个,qi是旧模型关于数据xi的预测值,pi是训练中模型对xi的预测值。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于中国科学院计算技术研究所厦门数据智能研究院,未经中国科学院计算技术研究所厦门数据智能研究院许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011186818.9/1.html,转载请声明来源钻瓜专利网。
- 上一篇:一种基于主动学习的样本标注方法
- 下一篇:一种整车制动系统的控制方法