[发明专利]一种基于Tensorflow的训练模型保存方法及驱动器、计算服务器有效
申请号: | 201810162033.4 | 申请日: | 2018-02-27 |
公开(公告)号: | CN108446173B | 公开(公告)日: | 2022-04-05 |
发明(设计)人: | 袁建勇;余远铭;王超 | 申请(专利权)人: | 华为技术有限公司 |
主分类号: | G06F9/48 | 分类号: | G06F9/48;G06F9/50 |
代理公司: | 深圳市深佳知识产权代理事务所(普通合伙) 44285 | 代理人: | 王仲凯 |
地址: | 518129 广东*** | 国省代码: | 广东;44 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 tensorflow 训练 模型 保存 方法 驱动器 计算 服务器 | ||
本申请提供了一种基于Tensorflow的训练模型保存方法:驱动器获取并存储第一标识和第二标识,第一标识为运行有参数服务器的计算设备的标识,第二标识为与参数服务器运行在同一个计算设备的计算服务器的标识。计算服务器从驱动器获取第一标识和第二标识,在确认计算服务器运行的计算设备的标识与第一标识相同,且计算服务器的标识与第二标识相同的情况下,存储训练模型,从而提高训练模型保存的成功率。
技术领域
本申请涉及电子信息领域,尤其涉及一种基于Tensorflow的训练模型保存方法及驱动器、计算服务器。
背景技术
张量流Tensorflow是一款Google出品的机器学习模型,它提供了分布式的机器学习以及深度学习能力。
图1为Tensorflow常见的一种架构Tensorflow On Spark的结构示意图,Tensorflow On Spark包括以下逻辑单元:Spark驱动器Driver、计算服务器和参数服务器。其中,Spark Driver将训练任务调度到多个计算服务器上,并向每个计算服务器分发训练数据。计算服务器依据训练任务和训练数据执行训练过程,得到模型参数的反馈值,参数服务器依据反馈值,修正模型参数(例如模型参数包括各层神经网络的权重与偏差),并在训练结束后,保存得到的模型参数。在训练结束后,计算服务器保存训练得到的模型。
一个或多个计算服务器运行在一个计算设备(计算设备为实体设备或虚拟设备)上,全部计算服务器分布运行在多个计算设备上。参数服务器运行在多个计算设备中的一个计算设备上。
而在实际应用中,图1所示的Tensorflow On Spark架构,可能出现训练模型保存失败的问题。
发明内容
申请人在研究的过程中发现,保存训练模型的计算服务器与参数服务器运行在同一个计算设备,是保证训练模型成功保存的一个关键因素。因此,本申请提供了一种基于Tensorflow的训练模型保存方法及驱动器、计算服务器,目的在于使得保存训练模型的计算服务器与参数服务器运行在同一个计算设备,以解决训练模型保存失败的问题。
本申请的第一方面提供了一种基于Tensorflow的训练模型保存方法,包括:驱动器获取第一标识,所述第一标识为运行有参数服务器的计算设备的标识。所述驱动器获取第二标识,所述第二标识为与所述参数服务器运行在同一个计算设备的计算服务器的标识。所述驱动器存储所述第一标识和所述第二标识,所述第一标识和所述第二标识为所述计算服务器存储训练模型的依据。基于所述第一标识和所述第二标识的定义,驱动器存储第一标识和第二标识后,为计算服务器确定自身是否与参数服务器运行在同一个计算设备上提供依据,使得与参数服务器运行在同一个计算设备上的计算服务器存储训练的模型,从而提高模型存储的成功率。
本申请的第二方面提供了一种基于Tensorflow的驱动器,包括:第一获取模块、第二获取模块和存储模块。其中,第一获取模块用于获取第一标识,所述第一标识为运行有参数服务器的计算设备的标识。第二获取模块用于获取第二标识,所述第二标识为与所述参数服务器运行在同一个计算设备的计算服务器的标识。存储模块用于存储所述第一标识和所述第二标识。基于Tensorflow的驱动器能够提高计算服务器存储训练的模型的成功率。
在一个实现方式中,在所述驱动器获取第二标识之前,还包括:所述驱动器将训练任务调度到多个计算服务器上,并向每个计算服务器发送训练数据,使得接收到所述训练任务和所述训练数据的每个计算服务器执行训练过程,得到模型参数的反馈值,所述反馈值用于所述参数服务器修正所述模型参数。所述驱动器收集执行所述训练过程的各个计算服务器所在的计算设备的标识。所述驱动器确定运行在目标计算设备上的所述计算服务器为所述与所述参数服务器运行在同一个计算设备的计算服务器,所述目标计算设备为与所述参数服务器运行的计算设备具有相同标识的计算设备。
在一个实现方式中,所述计算设备的标识包括:所述计算设备的IP地址。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于华为技术有限公司,未经华为技术有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/201810162033.4/2.html,转载请声明来源钻瓜专利网。