主页 > 人工智能  > 

知识蒸馏知识点

知识蒸馏知识点

1基于kl散度计算,学生模型需要用log_softmax处理

2 为了避免温度对梯度的影响,loss*T**2

操作 目的 教师 / 学生输出除以 软化概率分布,暴露类别间关系 损失乘以 抵消温度对梯度的缩放,维持梯度量级稳定,确保训练收敛性

import torch import torch.nn.functional as F # 原始logits(未缩放) z_teacher = torch.tensor([[3.0, 1.0, 0.5]]) z_student = torch.tensor([[2.5, 0.8, 0.3]], requires_grad=True) # 直接启用梯度 # 温度参数 T = 4.0 # 计算KL散度损失(带温度缩放) P = F.softmax(z_teacher / T, dim=1) log_Q = F.log_softmax(z_student / T, dim=1) # 此时梯度已追踪 # 损失计算 loss_unscaled = F.kl_div(log_Q, P, reduction='batchmean') loss_scaled
标签:

知识蒸馏知识点由讯客互联人工智能栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“知识蒸馏知识点