[发明专利]神经网络的生成方法和装置在审
申请号: | 202210591391.3 | 申请日: | 2022-05-27 |
公开(公告)号: | CN114912585A | 公开(公告)日: | 2022-08-16 |
发明(设计)人: | 沈力;陶大程 | 申请(专利权)人: | 京东科技信息技术有限公司 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08;H04L67/55 |
代理公司: | 北京英赛嘉华知识产权代理有限责任公司 11204 | 代理人: | 王达佐;马晓亚 |
地址: | 100176 北京市大兴区经济*** | 国省代码: | 北京;11 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 神经网络 生成 方法 装置 | ||
本公开的实施例公开了神经网络的生成方法和装置。该方法的一具体实施方式包括:获取待训练的神经网络,神经网络包括特征提取网络、筛选网络和输出网络,特征提取网络用于生成概率向量,筛选网络用于对输入的概率向量进行重参数处理,对重参数处理结果应用基于最优传输的Top‑K算法,输出网络用于根据筛选网络的输出生成处理结果;获取训练数据集;将训练数据集中的待处理样本作为神经网络的输入,将输入的待处理样本对应的处理结果样本作为神经网络的期望输出,利用反向传播算法训练神经网络。该实施方式不仅基于最优传输的Top‑K算法解决了Top‑K算法无法计算梯度的问题,同时缓解了基于最优传输的Top‑K算法的神经网络的求解不准确的问题。
技术领域
本公开的实施例涉及计算机技术领域,具体涉及神经网络神经网络的生成方法和装置。
背景技术
Top-K Coreset是一种通用的排序抽样的方法,能够实现用较小集合逼近原始集合。一些情况下,Top-K Coreset被应用于神经网络中。为了能够训练基于Top-K Coreset的神经网络,需要损失函数在每个更新步骤中相对于输入都是可微的,但Top-K操作的实现算法通常涉及交换索引等操作,无法计算其梯度,从而很难被整合到神经网络的训练过程中。
基于此,现有常用的训练方法是采用两阶段的训练方式,具体使用特征提取部分的代理损失来训练神经网络所包括的特征提取部分,然后利用训练好的特征提取部分进行特征提取,再利用Top-K Coreset等根据提取的特征进行后续处理。这种方式在训练过程中完全规避了Top-K操作,但会使得训练和最终处理结果不一致。
发明内容
本公开的实施例提出了神经网络的生成方法和装置。
第一方面,本公开的实施例提供了一种神经网络神经网络的生成方法,该方法包括:获取待训练的神经网络,其中,神经网络包括特征提取网络、筛选网络和输出网络,特征提取网络用于提取样本特征,以及根据样本特征生成概率向量,筛选网络用于对输入的概率向量进行重参数处理,以及对重参数处理结果应用基于最优传输的Top-K算法,输出网络用于根据筛选网络的输出生成处理结果;获取训练数据集,其中,训练数据集中的训练数据包括待处理样本和待处理样本对应的标签数据,待处理样本的类型包括以下至少一项:文本、图像、音频以及视频;将训练数据集中的待处理样本作为神经网络的输入,将输入的待处理样本对应的标签数据作为神经网络的期望输出,利用反向传播算法训练神经网络,以得到训练完成的神经网络。
在一些实施例中,上述筛选网络利用Gumbel Trick对输入的概率向量进行重参数处理。
在一些实施例中,上述筛选网络利用Sinkhorn算法实现基于最优传输的Top-K算法。
在一些实施例中,上述概率向量中的各元素分别与预设信息集中的信息对应;以及输出网络生成的处理结果用于指示从预设信息集中选取的信息。
第二方面,本公开的实施例提供了一种信息推送方法,该方法包括:获取候选推送信息集;利用预先训练的神经网络从候选推送信息集中选取信息,其中,神经网络利用如第一方面中最后一种实现方式描述的方法训练得到;对从候选推送信息集中选取的信息进行推送。
第三方面,本公开的实施例提供了一种神经网络的生成装置,该装置包括:第一获取单元,被配置成获取待训练的神经网络,其中,神经网络包括特征提取网络、筛选网络和输出网络,特征提取网络用于提取样本特征,以及根据样本特征生成概率向量,筛选网络用于对输入的概率向量进行重参数处理,以及对重参数处理结果应用基于最优传输的Top-K算法,输出网络用于根据筛选网络的输出生成处理结果;第二获取单元,被配置成获取训练数据集,其中,训练数据集中的训练数据包括待处理样本和待处理样本对应的标签数据,待处理样本的类型包括以下至少一项:文本、图像、音频、视频;训练单元,被配置成将训练数据集中的待处理样本作为神经网络的输入,将输入的待处理样本对应的标签数据作为神经网络的期望输出,利用反向传播算法训练神经网络,以得到训练完成的神经网络。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于京东科技信息技术有限公司,未经京东科技信息技术有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202210591391.3/2.html,转载请声明来源钻瓜专利网。