[发明专利]一种基于联邦学习与多任务学习的模型训练方法有效
申请号: | 202011194414.4 | 申请日: | 2020-10-30 |
公开(公告)号: | CN112348199B | 公开(公告)日: | 2022-08-30 |
发明(设计)人: | 谢在鹏;陈瑞锋;叶保留;朱晓瑞;屈志昊;徐媛媛 | 申请(专利权)人: | 河海大学 |
主分类号: | G06N20/00 | 分类号: | G06N20/00;G06N3/04;G06N3/08;G06F9/50 |
代理公司: | 南京经纬专利商标代理有限公司 32200 | 代理人: | 田凌涛 |
地址: | 210000 江苏*** | 国省代码: | 江苏;32 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 联邦 学习 任务 模型 训练 方法 | ||
1.一种基于联邦学习与多任务学习的模型训练方法,用于同步实现针对至少一个目标神经网络的参数化训练,并且各个目标神经网络彼此之间具有相同结构的全连接层;其特征在于:基于参数服务器、以及各个工作节点终端,按如下步骤A至步骤C,同步实现各目标神经网络的参数化训练;
步骤A.分别针对各目标神经网络,将其中各个全连接层划为目标神经网络的后部模型,以及将其中剩余部分划为目标神经网络的前部模型,然后进入步骤B;
步骤B.参数服务器根据各个工作节点终端的参数属性,构建由满足预设参数要求的各个工作节点终端所组成的交换网络,由参数服务器负责各目标神经网络的前部模型,交换网络中的各个工作节点终端共同负责各目标神经网络的后部模型,然后进入步骤C;
步骤C.参数服务器与交换网络中的各个工作节点终端,根据各目标模型分别所对应的样本训练数据,应用多任务学习模式,针对各目标神经网络进行参数化训练,获得训练后的各个目标神经网络;
步骤B至步骤C的执行过程中,交换网络中各个工作节点终端分别执行发送模型参数到其他工作节点终端的进程、以及接收其他工作节点终端所发送模型参数的进程;其中,交换网络中各个工作节点终端分别按如下步骤III1至步骤III23,执行发送模型参数到其他工作节点终端的进程;
步骤III1.工作节点终端向参数服务器发送自身算力电量带宽然后进入步骤III2;
步骤III2.工作节点终端轮询等待来自参数服务器的确认信息,然后进入步骤III3;
步骤III3.工作节点终端初始化其所对应的拒绝节点字典为空,然后进入步骤III4;
步骤III4.工作节点终端开启其对参数服务端的监听,工作节点终端接收来自参数服务器的各目标神经网络前部模型的参数,然后进入步骤III5;
步骤III5.工作节点终端接收参数服务器所发送的各目标神经网络前部模型的参数,并更新其内部各目标神经网络前部模型的参数,然后进入步骤III6;
步骤III6.工作节点终端应用自身数据针对所接收各目标神经网络训练预设Cn轮,然后进入步骤III7;
步骤III7.工作节点终端将各目标神经网络前部模型的参数发送给参数服务端,然后进入步骤III8;
步骤III8.工作节点终端判断其自身电量与通讯资源是否均充足,即是否大于且是否大于是则进入步骤III9;否则返回步骤III17;
步骤III9.工作节点终端向参数服务器发送加入交换网络的申请,然后进入步骤III10;
步骤III10.工作节点终端向参数服务器发送请求加入网络列表的申请,然后进入步骤III11;
步骤III11.工作节点终端判断是否接收到来自参数服务器发送的网络列表,是则工作节点终端开启接收其他工作节点终端所发送模型参数的进程、以及开启对其他工作节点终端所发送参数的拒绝接收进程,并进入步骤III12;否则返回步骤III10;
步骤III12.工作节点终端从所接收网络列表中随机选择一个其他工作节点终端,进入步骤III13;
步骤III13.工作节点终端判断其所选择的其他工作节点终端是否存在于其拒绝节点字典内,是则返回步骤III12;否则进入步骤III14;
步骤III14.工作节点终端将其所接收各目标神经网络中后部模型的参数发送给其所选择的其他工作节点终端,然后进入步骤III15;
步骤III15.工作节点终端判断是否接收到来自其所选择的其他工作节点终端的拒绝消息,是则将所选择的其他工作节点终端的IP地址加入工作节点终端所对应的拒绝节点字典中,并定义拒绝节点字典中该所对应的value值为拒绝次数然后进入步骤III16;否则直接进入步骤III16;
步骤III16.工作节点终端判断其发送数据所到达的其他工作节点终端的数量是否大于预设数量阈值是则进入步骤III17;否则返回步骤III12;
步骤III17.工作节点终端所对应拒绝节点字典的各个value分别自减1进行更新,然后进入步骤III18;
步骤III18.工作节点终端删除拒绝节点字典中value等于0的key-value,然后进入步骤III19;
步骤III19.使用工作节点终端自身数据进行Cm轮训练,然后进入步骤III20;
步骤III20.工作节点终端测试其所接收各目标神经网络的准确率损失然后进入步骤III21;
步骤III21.工作节点终端关闭所有监听,然后进入步骤III22;
步骤III22.工作节点终端判断其所接收各目标神经网络的准确率是否大于预设准确率阈值且损失小于预设损失阈值是则工作节点终端完成对其所接收各目标神经网络后部模型的训练;否则进入步骤III23;
步骤III23.工作节点终端判断是否所接收到来自参数服务器端的各目标神经网络前部模型参数的更新,是则返回步骤III4;否则返回步骤III8。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于河海大学,未经河海大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011194414.4/1.html,转载请声明来源钻瓜专利网。