1.一种用于分类的神经网络的训练方法,其特征在于,包括:获取样本数据,并将所述样本数据分别输入至教师网络和待训练的学生网络,得到第一预测结果和第二预测结果;
基于所述第一预测结果、所述第二预测结果以及所述学生网络的当前网络参数,确定预更新的第一网络参数;
基于所述预更新的第一网络参数对所述学生网络和所述教师网络对应的温度系数进行更新,并基于更新后的温度系数对所述学生网络的当前网络参数进行更新,得到更新后的学生网络。
2.根据权利要求1所述的方法,其特征在于,所述基于所述第一预测结果、所述第二预测结果以及所述学生网络的当前网络参数,确定预更新的第一网络参数,包括:获取所述学生网络的当前网络参数;
基于所述第一预测结果和所述第二预测结果,确定训练损失;
基于训练损失对获取的所述学生网络的当前网络参数进行调整,确定预更新的第一网络参数。
3.根据权利要求1或2所述的方法,其特征在于,所述基于所述预更新的第一网络参数对所述学生网络和所述教师网络对应的温度系数进行更新,包括:基于所述预更新的第一网络参数,对与所述学生网络训练相关联的元参数进行更新;
其中,所述元参数用于确定所述学生网络和所述教师网络对应的温度系数;
基于更新后的元参数对温度系数进行更新。
4.根据权利要求3所述的方法,其特征在于,所述样本数据包括训练样本数据和验证样本数据;
所述将所述样本数据分别输入至教师网络和待训练的学生网络,得到第一预测结果和第二预测结果,包括:将所述训练样本数据分别输入至教师网络和待训练的学生网络,得到第一预测结果和第二预测结果;
所述基于所述预更新的第一网络参数,对与所述学生网络训练相关联的元参数进行更新,包括:将所述预更新的第一网络参数作为所述学生网络的当前网络参数,得到预更新学生网络;
将所述验证样本数据输入至所述预更新学生网络,得到第三预测结果;
基于所述第三预测结果和所述验证样本数据对应的标注信息对所述元参数进行更新。
5.根据权利要求4所述的方法,其特征在于,所述基于所述第三预测结果和所述验证样本数据对应的标注信息对所述元参数进行更新,包括:基于所述第三预测结果和所述验证样本数据对应的标注信息,确定验证损失;
基于所述验证损失对所述元参数进行更新。
6.根据权利要求5所述的方法,其特征在于,所述基于所述第三预测结果和所述验证样本数据对应的标注信息,确定验证损失,包括:基于所述第三预测结果和所述验证样本数据对应的标注信息,确定所述预更新学生网络预测错误的错误样本数据;
基于所述第三预测结果中所述错误样本数据对应的预测结果的置信度信息和所述验证样本数据对应的标注信息对应的预设置信度信息,确定所述验证损失。
7.根据权利要求6所述的方法,其特征在于,所述基于所述第三预测结果中所述错误样本数据对应的预测结果的置信度信息和所述验证样本数据对应的标注信息对应的预设置信度信息,确定所述验证损失,包括:针对任一错误样本数据,计算该错误样本数据的第三预测结果中各分类结果的置信度信息与,该错误与该错误样本数据的标注信息中各分类结果的预设置信度信息的差的平方和;
将各错误样本数据对应的平方和之和作为所述验证损失。
8.根据权利要求3~7任一所述的方法,其特征在于,所述基于更新后的温度系数对所述学生网络的当前网络参数进行更新,包括:将所述样本数据输入至所述学生网络,以通过所述更新后的温度系数和所述学生网络确定第四预测结果;以及,将所述样本数据输入至所述教师网络,以通过所述教师网络和所述更新后的温度系数确定第五预测结果;
基于所述第四预测结果和所述第五预测结果,重新确定训练损失,并基于重新确定的训练损失对所述学生网络的当前网络参数进行更新。
9.根据权利要求4~8任一所述的方法,其特征在于,所述元参数包括与所述学生网络训练相关联的参数生成网络的网络参数;
所述基于更新后的元参数对温度系数进行更新,包括:
将更新后的元参数作为参数生成网络的网络参数,得到更新后的参数生成网络;
基于更新后的参数生成网络以及初始温度系数,重新确定温度系数。
10.根据权利要求9所述的方法,其特征在于,所述温度系数包括所述学生网络对应的第一温度系数和所述教师网络对应的第二温度系数;
所述基于更新后的参数生成网络以及初始温度系数,重新确定温度系数,包括:基于更新后的参数生成网络以及所述初始温度系数,确定所述第一温度系数和所述第二温度系数。
11.根据权利要求9或10所述的方法,其特征在于,所述元参数包括所述学生网络对应的第一元参数和所述教师网络对应的第二元参数;
所述将更新后的元参数作为参数生成网络的网络参数,得到更新后的参数生成网络,包括:将更新后的第一元参数作为所述参数生成网络的网络参数,得到更新后的第一参数生成网络;以及,将更新后的第二元参数作为所述参数生成网络的网络参数,得到更新后的第二参数生成网络;
所述基于更新后的参数生成网络以及初始温度系数,重新确定温度系数,包括:基于所述更新后的第一参数生成网络以及所述初始温度系数,确定所述学生网络对应的第一温度系数;以及,基于所述更新后的第二参数生成网络以及所述初始温度系数,确定所述教师网络对应的第二温度系数。
12.一种图像分类方法,其特征在于,包括:
获取待检测图像;
基于权利要求1~11任一所述的用于分类的神经网络的训练方法训练得到的学生网络对所述待检测图像进行识别,得到所述待检测图像对应的分类结果。
13.一种用于分类的神经网络的训练装置,其特征在于,包括:第一获取模块,用于获取样本数据,并将所述样本数据分别输入至教师网络和待训练的学生网络,得到第一预测结果和第二预测结果;
确定模块,用于基于所述第一预测结果、所述第二预测结果以及所述学生网络的当前网络参数,确定预更新的第一网络参数;
更新模块,用于基于所述预更新的第一网络参数对所述学生网络和所述教师网络对应的温度系数进行更新,并基于更新后的温度系数对所述学生网络的当前网络参数进行更新,得到更新后的学生网络。
14.一种图像分类装置,其特征在于,包括:
第二获取模块,用于获取待检测图像;
识别模块,用于基于权利要求1~11任一所述的用于分类的神经网络的训练方法训练得到的学生网络对所述待检测图像进行识别,得到所述待检测图像对应的分类结果。
15.一种计算机设备,其特征在于,包括:处理器、存储器和总线,所述存储器存储有所述处理器可执行的机器可读指令,当计算机设备运行时,所述处理器与所述存储器之间通过总线通信,所述机器可读指令被所述处理器执行时执行如权利要求1至11任一项所述的用于分类的神经网络的训练方法的步骤,或执行如权利要求12所述的图像分类方法的步骤。
16.一种计算机可读存储介质,其特征在于,该计算机可读存储介质上存储有计算机程序,该计算机程序被处理器运行时执行如权利要求1至11任一项所述的用于分类的神经网络的训练方法的步骤,或执行如权利要求12所述的图像分类方法的步骤。