[发明专利]一种联邦学习模型训练方法、装置及联邦学习系统有效
申请号: | 202011473442.X | 申请日: | 2020-12-15 |
公开(公告)号: | CN112232528B | 公开(公告)日: | 2021-03-09 |
发明(设计)人: | 曾令仿;银燕龙;何水兵;毛旷;杨弢;任祖杰;陈刚 | 申请(专利权)人: | 之江实验室 |
主分类号: | G06N20/20 | 分类号: | G06N20/20;G06F21/60 |
代理公司: | 杭州求是专利事务所有限公司 33200 | 代理人: | 应孔月 |
地址: | 310023 浙江省杭州市余*** | 国省代码: | 浙江;33 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 联邦 学习 模型 训练 方法 装置 系统 | ||
1.一种联邦学习模型训练方法,所述方法应用于包括一个云端联邦学习子系统、个边缘计算服务器和个端设备的联邦机器学习训练,其中,个边缘计算服务器均与云端联邦学习子系统相连,每个边缘计算服务器与一个或多个端设备建立连接,,所述方法包括:
所述云端联邦学习子系统把公钥key分发给所述边缘计算服务器和所述端设备,用以对训练过程中需要交换的数据进行加密;
对所述边缘计算服务器负责所辖区域内端设备的梯度进行更新;
对所述云端联邦学习子系统负责所辖边缘计算服务器的梯度进行更新;
每T1轮次的本地更新会触发一次边缘计算服务器进行梯度收集,每T2轮次的边缘计算服务器更新会触发一次云端联邦学习子系统进行梯度收集,其中T1为端设备分歧阶段内部轮次,T2为边缘计算服务器分歧阶段内部轮次;
当模型达到长尾阶段,边缘计算服务器会根据自身维护的表,让每个被截断的端设备补足本地训练,云端联邦学习子系统会根据自身维护的表,让每个被截断的边缘计算服务器补足本地训练,其中为端设备k1训练截断比例数组,上标k1为端设备索引,总数为,为边缘计算服务器k2训练截断比例数组,上标k2为边缘计算服务器索引,总数为。
2.根据权利要求1所述的一种联邦学习模型训练方法,其特征在于,对边缘计算服务器负责所辖区域内端设备的梯度进行更新,包括:
每训练个本地轮次,端设备将会记录一次自身私有模型此时的与,待本地训练完成次之后,会首先将维护的个值发送给边缘计算服务器,其中为端设备k1第u1个分歧阶段的模型增量,为端设备k1第u1个分歧阶段的模型损失,u1为端设备本地分歧阶段索引,总数为U1;
边缘计算服务器针对每个端设备形成维损失数组,按照筛选出每轮需要截断的端设备,并将其少训练的维护在边缘计算服务器,将端设备k1该轮通信所应该截断的轮次发送给端设备;
端设备在收到后,会将第轮本地训练后的梯度上传给边缘计算服务器并进行联邦模型的参数更新。
3.根据权利要求1所述的一种联邦学习模型训练方法,其特征在于,对云端联邦学习子系统负责所辖边缘计算服务器的梯度进行更新,包括:
每训练个本地轮次,边缘计算服务器将会记录一次自身私有模型此时的与,待本地训练完成次之后,会首先将维护的个值发送给云端联邦学习子系统,其中为边缘计算服务器k2第u2个分歧阶段的模型增量,为边缘计算服务器k2第u2个分歧阶段的模型损失,u2为边缘计算服务器本地分歧阶段索引,总数为U2;
云端联邦学习子系统针对每个边缘计算服务器形成维损失数组,按照筛选出每轮需要截断的边缘计算服务器,并将其少训练的维护在云端联邦学习子系统,将边缘计算服务器k2该轮通信所应该截断的轮次发送给边缘计算服务器;
边缘计算服务器在收到后,会将第轮本地训练后的梯度上传给云端联邦学习子系统并进行联邦模型的参数更新。
4.一种联邦学习模型训练装置,所述装置应用于包括一个云端联邦学习子系统、个边缘计算服务器和个端设备的联邦机器学习训练,其中,个边缘计算服务器均与云端联邦学习子系统相连,每个边缘计算服务器与一个或多个端设备建立连接,,所述装置包括:
分发单元,用于所述云端联邦学习子系统把公钥key分发给所述边缘计算服务器和所述端设备,用以对训练过程中需要交换的数据进行加密;
第一更新单元,用于对所述边缘计算服务器负责所辖区域内端设备的梯度进行更新;
第二更新单元,用于对所述云端联邦学习子系统负责所辖边缘计算服务器的梯度进行更新;
收集单元,用于每T1轮次的本地更新会触发一次边缘计算服务器进行梯度收集,每T2轮次的边缘计算服务器更新会触发一次云端联邦学习子系统进行梯度收集,其中T1为端设备分歧阶段内部轮次,T2为边缘计算服务器分歧阶段内部轮次;
训练单元,用于当模型达到长尾阶段,边缘计算服务器会根据自身维护的表,让每个被截断的端设备补足本地训练,云端联邦学习子系统会根据自身维护的表,让每个被截断的边缘计算服务器补足本地训练,其中为端设备k1训练截断比例数组,上标k1为端设备索引,总数为,为边缘计算服务器k2训练截断比例数组,上标k2为边缘计算服务器索引,总数为。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于之江实验室,未经之江实验室许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202011473442.X/1.html,转载请声明来源钻瓜专利网。
- 上一篇:一种用于双足机器人下台阶的稳定行走方法
- 下一篇:一种数据缓冲方法及其系统