[发明专利]一种基于元学习的快速适应DRBM方法在审

专利信息
申请号: 202110134999.9 申请日: 2021-01-29
公开(公告)号: CN112862094A 公开(公告)日: 2021-05-28
发明(设计)人: 张新禹;刘子衿;任祖煜;霍凯;刘振;张双辉;刘永祥;姜卫东;黎湘;卢哲俊 申请(专利权)人: 中国人民解放军国防科技大学
主分类号: G06N3/08 分类号: G06N3/08;G06N3/04;G06F17/18
代理公司: 湖南企企卫知识产权代理有限公司 43257 代理人: 任合明
地址: 410073 湖*** 国省代码: 湖南;43
权利要求书: 查看更多 说明书: 查看更多
摘要:
搜索关键词: 一种 基于 学习 快速 适应 drbm 方法
【权利要求书】:

1.一种基于元学习的快速适应DRBM方法,其特征在于,该方法分为以下步骤:

S1.建立DRBM网络结构:DRBM的网络结构可以分为三层——可见层、隐藏层和分类层,每层包含若干个神经元,神经元的连接方式是同一层内部的节点之间没有任何连接,而层与层之间的节点互相以全连接的方式相互连接在一起;每个神经元的状态均为1或者0的二元取值,1表示激活,0表示未激活,激活意味着该神经元所代表的节点对数据进行了处理;DRBM的分布由神经元的值确定,其中可见层用于表示输入数据,可见层节点个数由输入数据维度决定,可见层节点取值为输入数据各维的取值;隐藏层按照最优化的方式获取观察数据的某种统计意义上的特征,隐藏层节点个数根据数据和任务不同人为进行调整;分类层单元根据隐藏层单元提取出的数据特征进行类别判定,分类层节点个数由数据类别数量决定;

DRBM网络由网络参数进行描述;假设可见层节点个数为l,隐藏层节点个数为m,分类层节点个数为n,可见层偏置向量为b,b为1行l列向量、隐藏层偏置向量为c,c为1行m列向量、分类层偏置向量为d,d为1行n列向量,输入层和隐藏层的权重矩阵为W,W为l行m列矩阵、分类层和隐藏层的权重矩阵为U,U为n行m列矩阵;设向量θ=(W,U,b,c,d),则训练DRBM网络的目的就是寻找最佳的θ值,来通过网络预测数据类别;

DRBM是一种基于能量函数确定的模型,其能量函数可以被定义为:

E(y,x,h)=-hWTxT-bxT-chT-deyT-hUTyT (1)

其中x表示可见层的状态向量,x为1行l列向量、h表示隐藏层单元的状态向量,h为1行m列向量、y表示分类层的状态向量,y为1行n列向量,y是标签的“独热”型表示,即所有节点中只有一个节点为1,其余节点均为0;x、h、y的联合概率分布为:

其中称为配分函数;

S2元学习阶段:

S2.1网络参数初始化:

初始化可见层的偏置向量b为1行l列的零矩阵、隐藏层的偏置向量c为1行m列的零矩阵、分类层的偏置向量d为1行n列的零矩阵,以及对应的梯度Δb为1行l列的零矩阵、Δc为1行m列的零矩阵、Δd为1行n列的零矩阵,可见层与隐藏层的偏置矩阵W为l行m列的零矩阵和分类层与隐藏层的偏置矩阵U为n行m列的零矩阵,以及对应的梯度ΔW为l行m列的零矩阵、ΔU为n行m列的零矩阵,网络参数θ初始值记作θ0;设置内部学习率α为0.01~0.5、外部学习率β为0.005~0.05、动量学习率m=0.5和惩罚系数p=10-4

S2.2完成一个任务的训练:

元学习阶段的训练以一个任务为基本单元,每个任务包含两个部分——支撑集和质询集;利用支撑集进行训练的过程称为内部学习,利用质询集进行训练的过程称为外部学习;每个任务的数据类别与其他任务可以相同也可以不同;元学习阶段的所有任务共同组成了训练任务;训练任务中的所有样本都既含有数据信息也含有标签信息;具体步骤如下:

S2.2.1将支撑集作为网络输入,θ0作为网络参数初始值,完成一次训练:

S2.2.1.1计算隐藏层的概率分布函数:

p(h|y,x)=sigmoid(x(0)W+y(0)U+c) (3)

其中x(0)为输入数据,y(0)为独热形式的输入数据标签,

S2.2.1.2得到隐藏层概率分布函数后,利用Gibbs采样得到隐藏层节点取值;

S2.2.1.3根据隐藏层节点取值重构可见层和分类层,分别计算可见层和分类层的概率分布函数:

其中c′表示c的所有可能取值;

S2.2.1.4求得可见层和分类层概率分布函数后,利用Gibbs采样x(1)~p(x|h)和y(1)~p(y|h)得到可见层和分类层的节点取值x(1)、y(1)

S2.2.1.5根据可见层和分类层的节点取值x(1)、y(1)再次计算隐藏层概率分布函数:

p(h|y,x)=sigmoid(x(1)W+y(1)U+c) (5)

并通过Gibbs采样得到h(1)~p(h|y,x);

S2.2.1.6根据x(0)、y(0)、x(1)、y(1)、h(0)、h(1),求得网络参数θ的更新梯度:

S2.2.1.7输入训练任务集中的任务,根据内部学习率α、动量学习率m和惩罚系数p对梯度进行修正:

其中i∈[1,ns],ns为支撑集中的样本数量;

S2.2.1.8根据梯度对网络参数θ进行更新:

输入支撑集数据进行训练更新后的网络参数记为θns

S2.2.2将质询集作为网络输入,θs作为网络参数初始值,完成一次训练:

按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度计算,根据公式(9)对修正梯度进行计算:

其中β为外部学习率,i∈[1,nq],nq为质询集中的样本数量;最后根据公式(8)完成网络参数更新,得到的网络参数记为θnq

在完成一个任务的训练后,对网络参数进行一次更新,并且更新只保留外部学习部分,即:

θt+1=θt+(θnqns) (10)

其中,t∈[1,nt],t表示第t个训练任务,nt表示训练任务数量,保存网络参数θt+1,作为下一个任务的网络参数初始值,一个任务的训练结束;

S2.3完成所有遍历:

每一次遍历需要训练若干个任务,任务个数根据数据集大小决定;当完成一个任务的训练后,将更新后的网络参数初始值作为下一个任务的网络参数初始值,依次重复S2.2的过程,直至所有任务完成一次训练,这就是一次遍历;完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历,最终得到的网络参数记为θnt

S3模型学习阶段:

S3.1将支撑集作为网络输入,θnt作为网络参数初始值,完成一次训练:

按照公式(3)、(4)、(5)、(6),依次完成概率分布函数计算、节点分布采样以及网络参数更新梯度的计算,再按照公式(11)对修正梯度进行计算:

其中i∈[1,ts],ts为测试任务的支撑集中的样本数量,最后根据公式(8)完成网络参数更新;

S3.2完成所有遍历:

每一次遍历就是对测试任务中的支撑集完成一次训练,完成一次遍历后,将更新后的网络参数作为下一次遍历的网络参数初始值,直至完成所有遍历,最终得到的网络参数记为θt

S3.3将质询集中的数据输入网络,θt作为网络参数,依次计算出第i个类别的预测概率:

prediction(i)=repeat(d(i),tq)+log(exp(x(0)·W+T(i)·U+c)+1) (12)

其中T为tq行,nc列的矩阵,tq为质询集样本数量,nc为总类别数,T(i)表示偏置向量T除第i列为1外其余列全为0;d(i)表示偏置向量d除第i列为1外其余列全为0;repeat(d(i),tq)表示将偏置向量d(i)重复tq次,变为tq行,n列的矩阵;

计算完nc个类别的预测概率后,取其中最大值所在列,作为类别预测结果,完成目标分类。

下载完整专利技术内容需要扣除积分,VIP会员可以免费下载。

该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于中国人民解放军国防科技大学,未经中国人民解放军国防科技大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服

本文链接:http://www.vipzhuanli.com/pat/books/202110134999.9/1.html,转载请声明来源钻瓜专利网。

×

专利文献下载

说明:

1、专利原文基于中国国家知识产权局专利说明书;

2、支持发明专利 、实用新型专利、外观设计专利(升级中);

3、专利数据每周两次同步更新,支持Adobe PDF格式;

4、内容包括专利技术的结构示意图流程工艺图技术构造图

5、已全新升级为极速版,下载速度显著提升!欢迎使用!

请您登陆后,进行下载,点击【登陆】 【注册】

关于我们 寻求报道 投稿须知 广告合作 版权声明 网站地图 友情链接 企业标识 联系我们

钻瓜专利网在线咨询

周一至周五 9:00-18:00

咨询在线客服咨询在线客服
tel code back_top