[发明专利]一种基于数据并行策略的分布式深度学习方法及系统有效
申请号: | 201810662859.7 | 申请日: | 2018-06-25 |
公开(公告)号: | CN109032671B | 公开(公告)日: | 2022-05-03 |
发明(设计)人: | 李明;侯孟书;詹思瑜;董浩;王瀚;席瑞;董林森 | 申请(专利权)人: | 电子科技大学 |
主分类号: | G06F9/38 | 分类号: | G06F9/38;G06N3/04;G06N3/08 |
代理公司: | 电子科技大学专利中心 51203 | 代理人: | 周刘英 |
地址: | 611731 四川省成*** | 国省代码: | 四川;51 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 数据 并行 策略 分布式 深度 学习方法 系统 | ||
1.一种基于数据并行策略的分布式深度学习方法,其特征在于,包括下列步骤:
步骤1:输入用户基于PyTorch编写的待训练的神经网络模型,得到待训练的PyTorch神经网络模型P-Model;以及用户为所述P-Model设置下列训练参数:
分布式训练参数,包括通信阈值、工作节点数量N和分布式更新算法;
工作节点训练参数,包括迭代次数、批尺寸batch size、损失函数和工作节点的optimizer;
全局训练参数,包括精度期望值和参数服务器的optimizer;
步骤2:将P-Model、工作节点训练参数和分布式训练参数中的通信阈值发送到N个工作节点,以及将P-Model、全局训练参数和分布式训练参数中的分布式更新算法发送至参数服务器,将参数服务器上的P-Model定义为全局神经网络模型;
步骤3:将训练所述P-Model的训练数据转化为RDD数据类型;并通过大数据分布式处理引擎Spark将转化后的训练数据等分为N份后分发到所述N个工作节点;
步骤4:各工作节点基于本地的训练数据和工作节点训练参数,对本地的P-Model进行迭代训练,更新本地神经网络模型参数;
各工作节点在满足通信阈值时,计算本地神经网络模型参数的更新量并上传至参数服务器,其中更新量为当前神经网络模型参数与上次发生通信时的神经网络模型参数的差值;
参数服务器基于工作节点上传的更新量、分布式更新算法、以及全局训练参数训练全局神经网络模型,更新全局神经网络模型参数并下发至工作节点,工作节点将接收的全局神经网络模型参数同步至本地的P-Model;
步骤5:重复执行步骤4,至到满足模型训练终止条件,参数服务器将全局神经网络模型输出;
其中,模型训练终止条件为:各工作节点的实际迭代训练次数均达到工作节点训练参数中的迭代次数;
或者模型训练终止条件为:参数服务器每次更新全局神经网络模型参数后,训练精度达到精度期望值。
2.如权利要求1所述的方法,其特征在于,参数服务器更新全局神经网络模型参数具体方式为:将参数服务器的optimizer的梯度替换为工作节点上传的更新量;再通过所述参数服务器的optimizer对全局神经网络模型的参数进行更新。
3.如权利要求1或2所述的方法,其特征在于,分布式更新算法为异步随机梯度下降ASGD或Hogwild!。
4.如权利要求1或2所述的方法,其特征在于,参数服务器将全局神经网络模型存入HDFS。
5.如权利要求1或2所述的方法,其特征在于,所述通信阈值包括通信粒度、间隔通信轮数M,当工作节点在本地训练M轮通信粒度后,认为满足通信阈值。
6.如权利要求5所述的方法,其特征在于,所述通信粒度为epoch或batch。
7.一种基于数据并行策略的分布式深度学习系统,其特征在于,包括大数据分布式处理引擎Spark、PyTorch深度学习训练框架、轻量级Web框架Flask、urllib2模块、pickle模块、参数设置模块和数据转化模块;
其中,所述PyTorch深度学习训练框架用于用户编写待训练的PyTorch神经网络模型P-Model;
参数设置模块用于设置所述P-Model的分布式训练参数,工作节点训练参数和全局训练参数;其中,所述分布式训练参数,包括通信阈值、工作节点数量N和分布式更新算法;所述工作节点训练参数,包括:迭代次数、批尺寸batch size、损失函数和工作节点的optimizer;所述全局训练参数,包括:精度期望值和参数服务器的optimizer;
轻量级Web框架Flask根据P-Model、分布式更新算法和全局训练参数建立参数服务器;
所述大数据分布式处理引擎Spark根据用户设置的工作节点数量N选取分布式集群中N个节点作为工作节点;并将所述P-Model、工作节点训练参数和通信阈值传送至工作节点并建立本地的PyTorch模型训练模块;
所述数据转化模块用于将训练所述P-Model的训练数据转化为Spark支持的RDD数据类型,并通过大数据分布式处理引擎Spark将转化后训练数据等分为N份后分发到所述N个工作节点;
各PyTorch模型训练模块通过urllib2模块和pickle模块与参数服务器进行参数交互,所述urllib2模块用于PyTorch模型训练模块与参数服务器之间的网络通信,pickle模块用于对待发送的参数进行序列化处理,以及对接收的参数进行反序列化处理;
各PyTorch模型训练模块基于工作节点训练参数,以及分发到的训练数据对本地的P-Model进行迭代训练,不断更新本地神经网络模型参数;并在满足通信阈值时,计算本地神经网络模型参数的更新量并上传至参数服务器,其中更新量为当前神经网络模型参数与上次发生通信时的神经网络模型参数的差值;
参数服务器上的P-Model为全局神经网络模型,参数服务器基于从工作节点收到的更新量、用户设置的分布式更新算法、以及全局训练参数,更新全局神经网络模型参数并将其返回至该工作节点,该工作节点上的PyTorch模型训练模块根据接收的全局神经网络模型参数对本地神经网络模型参数进行同步;
参数服务器监控各PyTorch模型训练模块的训练过程,当满足模型训练终止条件时,参数服务器将全局神经网络模型输出;
其中,模型训练终止条件为:各工作节点的实际迭代训练次数均达到工作节点训练参数中的迭代次数;
或者模型训练终止条件为:参数服务器每次更新全局神经网络模型参数后,训练精度达到精度期望值。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于电子科技大学,未经电子科技大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/201810662859.7/1.html,转载请声明来源钻瓜专利网。
- 数据显示系统、数据中继设备、数据中继方法、数据系统、接收设备和数据读取方法
- 数据记录方法、数据记录装置、数据记录媒体、数据重播方法和数据重播装置
- 数据发送方法、数据发送系统、数据发送装置以及数据结构
- 数据显示系统、数据中继设备、数据中继方法及数据系统
- 数据嵌入装置、数据嵌入方法、数据提取装置及数据提取方法
- 数据管理装置、数据编辑装置、数据阅览装置、数据管理方法、数据编辑方法以及数据阅览方法
- 数据发送和数据接收设备、数据发送和数据接收方法
- 数据发送装置、数据接收装置、数据收发系统、数据发送方法、数据接收方法和数据收发方法
- 数据发送方法、数据再现方法、数据发送装置及数据再现装置
- 数据发送方法、数据再现方法、数据发送装置及数据再现装置