[发明专利]一种基于记忆力机制和图神经网络的小样本图像分类方法有效
申请号: | 202110872087.1 | 申请日: | 2021-07-30 |
公开(公告)号: | CN113688878B | 公开(公告)日: | 2022-08-19 |
发明(设计)人: | 张志忠;谢源;刘勋承;田旭东;马利庄 | 申请(专利权)人: | 华东师范大学 |
主分类号: | G06K9/62 | 分类号: | G06K9/62;G06N3/04;G06N3/08;G06V10/764;G06V10/774 |
代理公司: | 上海蓝迪专利商标事务所(普通合伙) 31215 | 代理人: | 徐筱梅;张翔 |
地址: | 200241 *** | 国省代码: | 上海;31 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 记忆力 机制 神经网络 样本 图像 分类 方法 | ||
1.一种基于记忆力机制和图神经网络的小样本图像分类方法,其特征在于利用图神经网络与记忆机制,借助学习好的概念知识帮助小样本模型进行推理预测,具体包括下述步骤:
步骤1:预训练
在整个训练集上学习一个有监督的特征提取器和线性分类器,并将其作为元训练阶段编码器和记忆库的初始化权值;
步骤2:元训练
通过编码器提取支撑集样本和查询集样本的特征,将其作为任务相关节点,所述支撑集样本的特征由构建的记忆库存储;所述记忆库采用更新方式进行优化,以逐步提纯判别性信息,最后从记忆库中挖掘每个类相关信息作为元知识,并通过一个图神经网络来传播任务相关节点以及元知识之间的相似性;
步骤3:元测试
通过任务相关节点和元知识节点得到分类结果,在元测试过程中,记忆库和其他模块不被更新,episode采样训练策略的样本来自测试集。
2.根据权利要求1所述基于记忆力机制和图神经网络的小样本图像分类方法,其特征在于所述步骤1具体包括如下步骤:
1.1:在整个训练集上训练一个有监督的特征提取器和线性分类器;
1.2:使用训练好的特征提取器和线性分类器分别作为元训练阶段编码器和记忆库的初始化权值。
3.根据权利要求1所述基于记忆力机制和图神经网络的小样本图像分类方法,其特征在于所述步骤2具体包括如下步骤:
2.1:采用一个包含支撑集样本和查询集样本的N-Way K-shot T-query任务,通过编码器提取支撑集样本S和查询集样本Q的特征表示作为任务相关的节点
2.2:使用类内均值计算支撑集样本中每个类的中心点fcen∈R[N,d],并将其与存储在记忆库中相同类别的原型点fp∈R[N,d]进行串接,将串接后的特征表示fcat∈R[N,2d]输入到一个全连接层减少维度以提纯语义信息,所述语义信息由下述(1)式进行约束和提纯:
maxI(fp,Y)-βI(fcen,fp) (1);
其中:I(.,.)表示互信息;Y表示标签;β表示拉格朗日系数;
所述记忆库由下述(2)式进行提纯优化:
将提纯后记忆库的特征表示fB∈R[N,d]与记忆库相同类别的原型点由下述(3)式的动量更新对记忆库进一步优化:
fp←λfp+(1-λ)fB (3);
其中:λ是动量系数;
2.3:计算类中心点与记忆库中每个原型点之间的余弦相似性,选择与中心点最近的k个原型点MK={m1,m2,…,mk},将k个原型点都与中心点拼接输入到一个聚合网络,并将k个原型点的信息由下述(4)式进行聚合,其输出作为该类的元知识节点扩充支撑集,作为该类别的伪样本:
其中:[.,.]为拼接操作;f(.;θagg)表示执行一个转换:θagg为R2d→Rd全连接层组成的参数化;aj为下述(5)式表示的mj和fcen[i]的相关性系数:
其中:τ为温度系数;.,.为余弦相似度;
2.4:将任务相关节点和元知识节点一起构造一个全连接的图G=(V,E),其中,每个节点代表一个样本的特征,边表示两个节点的相似性,两个节点来自同一个类则为1,否则为0,将与查询集相连的边由下述(6)式初始化为0.5:
其中:为扩充元知识节点后的支撑集;
2.5:对记忆增强的图神经网络每一层节点特征和边特征进行更新,给定前一层的节点特征和边特征,通过领域聚合过程更新节点特征,所述节点特征由下述(7)式进行更新:
其中:[.,.]为拼接操作;l为记忆增强模块的第l层;fnode(.;θnode)为节点更新网络;θnod为参数化;
所述边特征由下述(8)式重新计算:
2.6:经多层增强的图神经网络的更新,每个查询集节点属于某个类的概率可由下述(9)式计算为所有同类的支撑集节点与查询集节点边的值求和:
其中:δ(yi=Ck)为克罗内克函数,当yi=Ck,值为1,否则为0;
2.7:对记忆库优化目标为下述(10)式所示的最小化二元交叉熵损失函数
其中:ei和分别表示预测的查询集边标签和真实的查询集边标签;λl是第l层的权重系数;BCE表示二元交叉熵损失;
2.8:对记忆库优化目标为下述(11)式所示的另一最小化二元交叉熵损失函数
其中:分别表示预测的元知识边标签和真实的元知识边标签;λl是第l层的权重系数;BCE表示二元交叉熵损失;
2.9:记忆库最终的优化目标为下述(12)式所示的最小化损失函数
其中:α和β为平衡系数,α=0.2,β=0.01。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于华东师范大学,未经华东师范大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110872087.1/1.html,转载请声明来源钻瓜专利网。