[发明专利]用于训练预测模型的方法和装置有效
申请号: | 202010116709.3 | 申请日: | 2020-02-25 |
公开(公告)号: | CN111340220B | 公开(公告)日: | 2023-10-20 |
发明(设计)人: | 希滕;张刚;温圣召 | 申请(专利权)人: | 北京百度网讯科技有限公司 |
主分类号: | G06N3/0985 | 分类号: | G06N3/0985;G06N3/084;G06N3/044;G06N3/045;G06N3/0464 |
代理公司: | 北京英赛嘉华知识产权代理有限责任公司 11204 | 代理人: | 王达佐;马晓亚 |
地址: | 100085 北京市*** | 国省代码: | 北京;11 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 用于 训练 预测 模型 方法 装置 | ||
本公开涉及人工智能领域。本公开的实施例公开了用于训练预测模型的方法和装置。该预测模型用于预测神经网络结构的性能,该方法包括通过采样操作训练预测模型;采样操作包括:从已训练完成的超网络中采样出子网络,并对采样出的子网络进行训练,得到训练完成的子网络的性能信息;基于训练完成的子网络和对应的性能信息构建样本数据,并利用样本数据训练预测模型;响应于确定当前采样操作中训练得到的预测模型的精度不满足预设的条件,执行下一次采样操作,并在下一次采样操作中增加采样的子网络的数量。该方法可以降低神经网络模型结构的搜索成本。
技术领域
本公开的实施例涉及计算机技术领域,具体涉及人工智能技术领域,尤其涉及用于训练预测模型的方法和装置。
背景技术
随着人工智能技术和数据存储技术的发展,深度神经网络在许多领域取得了重要的成果。深度神经网络结构的设计对其性能具有直接的影响。传统的深度神经网络结构的设计由人工根据经验完成。人工设计网络结构需要大量的专家知识,并且针对不同的任务或应用场景需要分别针对性地进行网络结构的设计,成本较高。
NAS(neural architecture search,自动化神经网络结构搜索)是用算法代替繁琐的人工操作,自动搜索出最佳的神经网络架构。现有的模型结构自动搜索只能基于特定的约束条件进行搜索,例如针对指定的硬件设备型号进行搜索。然而,实际场景中的约束条件比较复杂,且变化很多,涉及到多种硬件种类,例如多种不同型号处理器。对每一种硬件,搜索约束也是繁多的,例如不同的延时约束。现有的方法需要针对每一种约束条件执行网络结构搜索,大量重复的网络结构搜索任务会消耗很多的计算资源,成本非常高。
发明内容
本公开的实施例提出了用于训练预测模型的方法和装置、电子设备和计算机可读介质。
第一方面,本公开的实施例提供了一种用于训练预测模型的方法,预测模型用于预测神经网络结构的性能,用于训练预测模型的方法包括通过采样操作训练预测模型;采样操作包括:从已训练完成的超网络中采样出子网络,并对采样出的子网络进行训练,得到训练完成的子网络的性能信息;基于训练完成的子网络和对应的性能信息构建样本数据,并利用样本数据训练预测模型;响应于确定当前采样操作中训练得到的预测模型的精度不满足预设的条件,执行下一次采样操作,并在下一次采样操作中增加采样的子网络的数量。
在一些实施例中,上述从已训练完成的超网络中采样出子网络,包括:采用初始的递归神经网络从已训练完成的超网络中采样出子网络;以及在对采样出的子网络进行训练之前,采样操作还包括:基于训练好的子网络的性能信息生成反馈信息,以基于反馈信息迭代更新递归神经网络;基于迭代更新后的递归神经网络重新从已训练完成的超网络中采样出子网络。
在一些实施例中,上述从已训练完成的超网络中采样出子网络,包括:从已训练完成的超网络中采样出未被采样过的子网络;以及上述基于训练完成的子网络和对应的性能信息构建样本数据,包括:基于当前采样操作中采样出的子网络和对应的性能信息、以及上一次采样操作中采样出的子网络和对应的性能信息构建样本数据。
在一些实施例中,上述采样操作还包括:响应于确定预测模型的精度满足预设的条件,基于当前的采样操作的训练结果生成训练完成的预测模型。
在一些实施例中,上述方法还包括:基于训练完成的预测模型对预设的模型结构搜索空间内的模型结构的性能预测结果,以及预设的深度学习任务场景的性能约束条件,在模型结构搜索空间中搜索出满足性能约束条件的神经网络模型结构。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于北京百度网讯科技有限公司,未经北京百度网讯科技有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202010116709.3/2.html,转载请声明来源钻瓜专利网。