[发明专利]模型训练方法和装置在审
申请号: | 202210223406.0 | 申请日: | 2022-03-07 |
公开(公告)号: | CN114627331A | 公开(公告)日: | 2022-06-14 |
发明(设计)人: | 杨一博;陈亚鑫;马本腾;陶大程 | 申请(专利权)人: | 北京沃东天骏信息技术有限公司;北京京东世纪贸易有限公司 |
主分类号: | G06V10/764 | 分类号: | G06V10/764;G06V10/778;G06V10/26;G06V10/82;G06N3/04;G06N3/08;G06K9/62 |
代理公司: | 中原信达知识产权代理有限责任公司 11219 | 代理人: | 王志远;张一军 |
地址: | 100176 北京市北京经济技术*** | 国省代码: | 北京;11 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 模型 训练 方法 装置 | ||
本发明公开了一种模型训练方法和装置,涉及人工智能技术领域。该方法的一具体实施方式包括:将样本集中预先标注语义分割标签的多幅训练图像分别输入训练完成的教师模型和待训练的学生模型,将所述学生模型对所述训练图像中像素所属类别的预测结果与所述语义分割标签之间的概率分布差异确定为第一差异;使用第一差异结合第二差异和/或第三差异构造所述学生模型的损失函数来训练所述学生模型。该实施方式能够增强模型对小众类别和/或图像细节信息的表达能力。
技术领域
本发明涉及人工智能技术领域,尤其涉及一种模型训练方法和装置。
背景技术
语义分割是当今计算机视觉领域的关键问题之一,广泛应用在自动驾驶、虚拟现实、智能诊疗、遥感等领域,其能够通过对每一像素所属类别的推理和预测实现场景的完整理解。由于语义分割任务需要在像素级别理解复杂场景,因此往往需要更大规模的复杂模型来学习强大的特征表示能力以确保预测精度并使得模型拥有较好的泛化性,由于模型尺寸大以及计算成本高,容易导致资源占用高、响应速度慢等问题,同时不适合部署在终端设备。
目前,可以采用知识蒸馏方法解决以上问题,即通过复杂的教师模型来训练轻量的学生模型,并将轻量的学生模型部署在终端设备。实际的语义分割任务中,这种方法存在以下问题:第一,受到主要类别(像素占比较大的类别,类别指的是像素所属类别,即标签含有的类别)的影响,模型对小众类别(像素占比较小的类别)的表达能力较弱;第二,模型对图像中的局部信息和细节信息的表达能力较弱,特别是当关注的目标仅占据图像的较小范围而背景占据较大范围时,容易忽视目标。
发明内容
有鉴于此,本发明实施例提供一种模型训练方法和装置,能够在知识蒸馏过程中通过提取联合特征和/或分离特征来增强模型对小众类别和/或图像细节信息的表达能力。
为实现上述目的,根据本发明的一个方面,提供了一种模型训练方法。
本发明实施例的模型训练方法包括:将样本集中预先标注语义分割标签的多幅训练图像分别输入训练完成的教师模型和待训练的学生模型,将所述学生模型对所述训练图像中像素所属类别的预测结果与所述语义分割标签之间的概率分布差异确定为第一差异;以及,所述学生模型和所述教师模型都包括主体网络和连接在所述主体网络之后的广义归一化层;对于所述学生模型和所述教师模型的主体网络输出的、对应于所述多幅训练图像的特征图:转换为所述类别的联合特征后进入所述广义归一化层,和/或,基于预设的切分规则在高度和宽度维度被切分为多个分离特征后进入所述广义归一化层;其中,每一类别的联合特征包括对应于所述多幅训练图像的特征图中的像素属于该类别的概率数据;每一分离特征包括该特征图处在同一切分空间的像素属于所述类别的概率数据;使用第一差异结合第二差异和/或第三差异构造所述学生模型的损失函数来训练所述学生模型;其中,第二差异是基于所述学生模型和所述教师模型的所述联合特征确定的,第三差异是基于所述学生模型和所述教师模型的所述分离特征确定的。
可选地,任一类别的联合特征根据以下步骤确定:获取相应主体网络输出的、对应于所述多幅训练图像的多通道特征图中各像素属于该类别的概率数据;将各像素属于该类别的概率数据合并为该类别的联合特征。
可选地,所述分离特征进一步由所述特征图执行通道维度的切分、并经类别维度的聚合而形成;经通道、高度和宽度维度切分形成的任一切分空间对应于任一类别的分离特征包括:该切分空间的像素属于该类别的概率数据。
可选地,在所述学生模型形成联合特征的情况下,所述教师模型形成联合特征;在所述教师模型形成联合特征的情况下,所述学生模型形成联合特征;在所述学生模型形成分离特征的情况下,所述教师模型形成分离特征;在所述教师模型形成分离特征的情况下,所述学生模型形成分离特征;以及,所述使用第一差异结合第二差异和/或第三差异构造所述学生模型的损失函数,包括:将第一差异和第二差异的加权和确定为所述损失函数;或者,将第一差异和第三差异的加权和确定为所述损失函数;或者,将第一差异、第二差异和第三差异的加权和确定为所述损失函数。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于北京沃东天骏信息技术有限公司;北京京东世纪贸易有限公司,未经北京沃东天骏信息技术有限公司;北京京东世纪贸易有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202210223406.0/2.html,转载请声明来源钻瓜专利网。