利索能及
我要发布
收藏
专利号: 2023110125464
申请人: 中国矿业大学
专利类型:发明专利
专利状态:已下证
更新日期:2025-10-14
缴费截止日期: 暂无
联系人

摘要:

权利要求书:

1.一种基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤如下:步骤S1、在CIFAR‑100数据集中随机采集K幅带标签的图像,10000

步骤S2、根据卷积层的深度和特征图的大小,将教师主干网络划分为n个教师模块,学生主干网络划分为n个学生模块,转入步骤S3;

步骤S3、利用教师模块构建学生分支网络,利用学生模块构建教师分支网络,再利用分支网络中包含的子模块构建辅助训练模块,具体如下:将步骤S2中教师主干网络和学生主干网络划分的n个模块分别用集合表示;教师模块的集合用 表示,T表示预训练教师主干网络, 表示教师主干网络的第i个教师模块;学生模块的集合用 表示,S表示学生网络, 表示学生主干网络的第i个学生模块;然后在教师模块 后延伸出分支,即依次接入n‑i个学生模块 以构成第v条教师分支网络分支,将这n‑i个学生模块称作该条教师分支网络的子模块,将第v条教师分支网络分支中的n‑i个子模块的集合记为 其中 表示该教师分支网络的第u个子模块;同理,在学生模块 后延伸出分支,依次接入n‑i个教师模块 以构成第v条学生分支网络,将这n‑i个教师模块称作该条学生分支网络的子模块,将这第v条学生分支网络中的n‑i个子模块的集合记为其中 表示该学生分支网络的第u个子模块;最多共有n‑1个学生网络分支和n‑1个教师网络分支,即1≤v≤n‑1且1≤u≤n‑i;最后将所有学生分支网络的子模块集合BT1,BT2,...,BTv,…,BTn‑1和所有教师分支网络的子模块集合BS1,BS2,...,BSv,…,BSn‑1共同构成辅助训练模块Baux={BT1,BT2,...,BTv,…,BTn‑1;BS1,BS2,...,BSv,…,BSn‑1};

转入步骤S4;

步骤S4、提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,转入步骤S5;

步骤S5、制定分组融合策略:

利用步骤S3中辅助训练模块中功能相对应的教师分支网络的子模块和学生分支网络的子模块共同构成n‑1个功能组,具体分组策略如下:按照相同位置的模块承担相同功能这个规则,利用第1个教师分支网络的第1个子模块的输出特征 和第1个学生分支网络的第1个子模块 的输出特征 共同建立第一个功能组 利用第1个教师分支网络的第2个子模块 的输出 第2个教师分支网络的第1个子模块 的输出特征 第1个学生分支网络的第2个子模块 的输出特征 以及第2个学生分支网络的第1个子模块的输出特征 共同建立为第二个功能组……;依次取出所有教师分支网络和学生分支网络中所有的子模块,将其中执行相同功能的子模块的输出特征划分为一组,直至建立出第n‑1个功能组将所有功能组的集合定义为G

={G1,G2,…,Gn‑1},1≤v≤n‑1,1≤u≤n‑i;

转入步骤S6;

步骤S6、构建特征融合模块,并利用步骤S5中n‑1个功能组经过特征融合模块融合后的特征分别与学生主干网络中功能相对应的n‑1个学生模块的输出特征计算特征融合损失,转入步骤S7;

步骤S7、将传统蒸馏损失、辅助训练损失以及特征融合损失加权求和,得到总的损失函数,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,转入步骤S8;

步骤S8、将测试数据集输入到训练好的学生网络,输出测试集中每个样本对应的预测结果,测试训练好的学生网络的准确率。

2.根据权利要求1所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S4中,提取步骤S2中各主干网络的输出特征以及步骤S3中辅助训练模块中各分支网络的输出特征,利用教师主干网络的输出特征和学生主干网络的输出特征计算传统蒸馏损失,利用辅助训练模块中各分支网络的输出特征与相应的主干网络的输出特征计算辅助训练损失,具体如下:

首先将第v条学生分支网络中的第u个子模块的输出特征 和第v条教师分支网络中的第u个子模块的输出特征 分别表示为:其中, 表示第v条学生分支网络的第u个子模块的特征提取函数, 表示第v条学生分支网络的第u个子模块; 表示第v条教师分支网络的第u个子模块的特征提取函数,表示第v条教师分支网络的第u个子模块,1≤v≤n‑1,1≤u≤n‑i;

再将教师主干网络的输出特征 经过softmax函数处理后的输出定义为PT,学生主干网络的输出特征 经过softmax函数处理后的输出定义为PS:式中t表示温度的超参数;

利用PT、PS计算出教师主干网络和学生主干网络的输出层特征间的知识蒸馏损失即传统的知识蒸馏损失Lcla:

Lcla=KL(PT||PS)

最后将第v条教师分支网络的输出特征经softmax函数处理后的类概率定义为 将第v条学生分支网络的输出特征经softmax函数处理后的类概率定义为利用 PT计算出教师分支网络和教师主干网络的输出特征间的KL损失LTv,利用 PS计算出学生分支网络和学生主干网络的KL损失LSv:最后将辅助训练模块中各分支网络输出特征与主干网络的输出特征之间的辅助训练损失Laux重建为:

Laux=LTv+LSv。

3.根据权利要求2所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S6中构建特征融合模块,并利用步骤S5中n‑1个功能组经过特征融合模块融合后的特征分别与学生主干网络功能相对应的n‑1个模块的输出特征计算特征融合损失,具体如下:

首先由3个大小为1×1、步长为1的卷积层和一次concat操作构成特征融合模块,同时利用注意力机制将特征融合模块的不同通道设置不同的注意力卷积网络生成不同的融合权值;在此特征融合模块中采用特征迭代融合的方法,即每两个特征根据不同的融合权值进行一次融合,再将得到的融合特征与下一个特征进行融合,如此逐次进行迭代融合直至遍历功能组中的所有元素;

再将特征融合模块的融合函数定义为fm,将第k个功能组Gk经过特征融合模块的输出特征表示为

其中,1≤k≤n‑1;

将学生主干网络划分的学生模块集合 的除去第一个学生模块后的n‑1个学生模块的输出特征集合定义为 利用L2归一化损失函数计算功能组经过特征融合模块后的输出特征 和特征集合FSO中的输出特征 之间的特征融合损失Lfuse:

4.根据权利要求3所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于,步骤S7中将传统蒸馏损失Lcla、辅助训练损失Laux以及特征融合损失Lfuse加权求和,得到总的损失函数Ltotality,并以此对学生网络的网络参数进行更新,最终获得训练好的学生网络,具体如下:Ltotality=λ1Lcla+λ2Laux+λ3Lfuse其中,λ1为传统知识蒸馏损失的权重超参数,λ2为辅助训练损失的权重超参数,λ3为特征融合损失函数的权重超参数。

5.根据权利要求4所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于:λ1=0.5,λ2=0.1,λ3=0.1。

6.根据权利要求2所述的基于中间层辅助特征模块融合匹配的知识蒸馏方法,其特征在于:t=4。