混合精度训练场景中,对比学习损失函数的一个注意点

笔者在之前的大规模对比学习训练过程(训练CLIP)中,发现在混合精度训练时候,对比学习的交叉熵损失(带温度系数)容易出现的一个小问题,特此笔记下,希望对读者有所帮助。

FesianXu 20220603 at Baidu Search Team

前言

笔者在之前的大规模对比学习训练过程(训练CLIP)中,发现在混合精度训练时候,对比学习的交叉熵损失(带温度系数)容易出现的一个小问题,特此笔记下,希望对读者有所帮助。如有谬误请联系指出,本文遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明并且联系笔者,谢谢

联系方式:

e-mail: FesianXu@gmail.com

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号


在对比学习训练过程中,经常会采用带温度系数的交叉熵损失,如式子(1-1)所示 其中的为温度系数,表示query和key,其中的表示正样本,而表示第个样本,其中query和key的定义见之前博文 [1],里面有更为细致的介绍,该损失函数在CLIP训练中有着应用,见博文[2]。而其中的温度系数很关键,是用于控制整个学习任务的『难易』程度的,这里我们举个例子。用以下代码为例,分两种情况, 当训练处于初始阶段时候, 打分接近均匀分布,那么假设,其中0.2为正样本打分,其他为负样本打分,注意到此时label位置的打分(也即是第0维打分0.2)是比其他位置的打分还要小的。我们选择四组温度系数,如以下代码所示,此时我们发现温度系数越小,其输出的损失值越大。这其实很好理解,由于对打分进行了扩大,而这种扩大是非线性的,参考指数函数的曲线,如Fig 1.所示,当线性扩大的时候,其指数是有着更为高阶的增大速度的。因此当的时候,np.exp(logit/t)[ 54.59815003 403.42879349 148.4131591 148.4131591 ],对比其原始的打分[0.2 0.3 0.25 0.25],我们从式子(1-1)中已经知道损失函数是正样本打分/样本打分总和的相反数,而当通过指数函数进行放大后,其损失就从原始的

变为,显然损失就被放大了,更易于模型刚开始的优化。而目标位置和其他位置的打分差别越大,其损失差别就会通过指数函数放大得更大,此时,我们认为越小的温度系数,模型训练越『简单』,可以认为此时我们将温度系数调小,是调小了整个训练任务的学习难度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def cross_entropy(logit, t, ind_pos=0):
print(logit)
print(np.exp(logit/t))
return -np.log(np.exp(logit[ind_pos]/t) / np.exp(logit/t).sum())

orig_score = [0.2,0.3,0.25,0.25]
orig_score = np.array(orig_score)
y0p05 = cross_entropy(orig_score, 0.05)
print(y0p05)
# 输出 loss: 2.626523375036445
# logit: [0.2 0.3 0.25 0.25]
# np.exp(logit/t): [ 54.59815003 403.42879349 148.4131591 148.4131591 ]

y0p5 = cross_entropy(orig_score, 0.5)
print(y0p5)
# 输出 loss: 1.4887933201471417
# logit: [0.2 0.3 0.25 0.25]
# np.exp(logit/t): [1.4918247 1.8221188 1.64872127 1.64872127]

y1p0 = cross_entropy(orig_score, 1.0)
print(y1p0)
# 输出 loss: 1.4369192960265726
# logit: [0.2 0.3 0.25 0.25]
# np.exp(logit/t): [1.22140276 1.34985881 1.28402542 1.28402542]

y2p0 = cross_entropy(orig_score, 2.0)
print(y2p0)
# 输出 loss: 1.4114506070510497
# logit: [0.2 0.3 0.25 0.25]
# np.exp(logit/t): [1.10517092 1.16183424 1.13314845 1.13314845]

Fig 1. 指数函数会将打分进行非线性扩大。

然而,当训练到一定程度时候,模型的打分已经比较置信了,假设为,那么同样采用以上代码,我们有输出。此时我们发现温度系数越大,其损失越大了,原因和之前分析的一样,由于指数函数的非线性造成的。此时反而应该适当增大温度系数才能让模型继续学习下去,此时我们能认为模型已经『学习』到较为简单的表征了,需要加大训练难度才能让模型学习到更为细致(或者困难)的表征,此时我们将温度系数调大,可以认为是加大了整个任务的『难度』。总得来说,在训练不同阶段通过调节温度系数的大小可以控制整个任务的难度,从而让模型学习出不同粒度的表征,可以看成是一种『coarse to fine,从粗到细』的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# logit: [0.7  0.1  0.05 0.15]
# np.exp(logit/t): [1.20260428e+06 7.38905610e+00 2.71828183e+00 2.00855369e+01]
# 输出 loss: 2.5105927394272577e-05
====
# logit: [0.7 0.1 0.05 0.15]
# np.exp(logit/t): [4.05519997 1.22140276 1.10517092 1.34985881]
# 输出 loss: 0.6453200240879728
====
# logit: [0.7 0.1 0.05 0.15]
# np.exp(logit/t): [2.01375271 1.10517092 1.0512711 1.16183424]
# 输出 loss: 0.9737318345317211
====
# logit: [0.7 0.1 0.05 0.15]
# np.exp(logit/t): [1.41906755 1.0512711 1.02531512 1.07788415]
# 输出 loss: 1.1702870665310683
====

然而手动在不同阶段调节温度系数很麻烦,而且可能并不是最优的,因此就有人尝试将温度系数设置为可学习的(learnable temperature),通过梯度更新温度系数。这一点在全精度训练下并没有太大问题,但是为了加大训练的batch size,通常会在半精度场景中进行训练(auto mixed precision,amp),此时容易出现下溢出的问题,损失函数如Fig 2.所示,训练过程会突然崩溃,此来龙去脉且听笔者道来。

Fig 2. 在加入了带可学习的温度系数损失函数后进行从头训练,会经常出现不稳定的现象。

我们以paddle的半精度训练为例子,如文档[3]所示,半精度即是FP16。如Fig 3.所示,此时实际情况中模型训练中的某些变量, 比如grad (特别是 activationgrad), 可能会因小于 FP16的精度低而变成0 (红色线条以左); 另一方面在FP16 的表示范围的中有很大的一部分(从蓝色线最大值往右) 却没有被利用到,导致整个分布(绿色方块)只有一小部分(红线-蓝线之间)落在FP16的表示范围内。此时对梯度的数值进行整体的一个放大(或者缩小),能够更充分的利用FP16 的表示范围。AMP 会在反向开始前对 loss 进行放大,并在执行任何梯度相关操作(e.g. gradient-clip, update) 之前对 gredient 进行缩放恢复原来的大小。

通常来说,框架会提供动态损失放缩(Dynamic loss scaling)的机制,此时框架会检测scale大小是否合适,当scaling up 不足时,仍会有部分较小变量会被表示成 0而损失精度;当scaling up 过度时,变量超过FP16表示范围出现 nan or inf, 同样造成精度损失。此时动态损失放缩会采用自动检测梯度值的方法:

当连续incr_every_n_steps(int)个batch 中所有的gradient 都在FP16 的表示范围, 将scaling factor 增大incr_ratio(float)倍;当有连续decr_every_n_nan_or_inf(int)个batch 中gradient 里出现 nan / inf时, scaling factor 缩小 decr_ratio(float)倍.

如Fig 4.所示,在 Dynamic loss scaling 中,框架在每一个 iteration 都会依据当前 gradients 是否出现 nan or inf 还有用户设置的 Dynamic loss scaling 参数来动态调整 loss scaling factor 的大小,将gradient 尽量保持在 FP16 的表示范围之内。

Fig 3. 如果不进行loss scale,那么会有很大一部分的数值在FP16中溢出。

Fig 4. 正常情况下loss scale是波动上下变化的,框架会自动检测并且调整。

然而,如果在半精度场景中把温度系数设置为可学习的,通过之前的讨论我们知道温度系数控制了整个学习任务的难易程度,同时其也可以控制损失大小,也就意味着在训练到某个阶段,模型可能会倾向于连续降低温度系数以进一步降低损失(此时继续优化模型参数本身需要更多的iteration,但是『优化』温度系数会获得短期更快的损失下降)。此时温度系数会持续下降,如Fig 5.右图所示,而温度系数一直下降意味着的值会迅速增大,而其指数函数增大得更快,导致loss scale持续下降(否则就超过了FP16的表示范围),知道loss scale降到了0都无法挽救为止,此时训练将会出现nan或者inf,导致训练过程崩塌。

Fig 5. loss scale和可学习温度系数的大小变化曲线图。

怎么缓解这个问题呢?笔者实践中是将温度系数固定在0.05开始模型从头的训练,等模型训练到差不到的时候才开放温度系数的可学习性。

补充

在知乎有个朋友问了个有意思的问题,感觉值得写个补充回答,因而在博客这里回答下,知乎链接:https://zhuanlan.zhihu.com/p/526175260,问题为:

serendipity: 根据文章所述,训练崩溃的原因是:网络在优化的时候倾向于缩小温度系数,然而这会导致太大,发生梯度爆炸,即使 loss_scale变成0也拉不回fp16的表示范围。如果是这样的话,即使用fp32也会遇到训练崩溃的问题吧,缩小温度系数导致梯度爆炸。

首先注意到,loss scale是在混合精度训练过程中才存在的概念,这个用于将值的表示范围平衡到fp16的有效范围内,而在fp32中是没有这个概念的。因此这里的逻辑链条其实应该是,由于在fp16的情况下,温度系数倾向于缩小(但不会一直缩小,按我的经验应该是会下降到0.05左右的量级),此时的数值很大(对于FP16的表示范围而言),导致梯度会变得很大,有一个瞬间梯度将会超出了fp16的表示范围从而使得loss scale一直下降。我们从IEEE 754中可知道FP16的表示范围在,FP32的表示范围在,如Fig A1所示。

Fig A1. IEEE 754规定的浮点数表示范围。

但是问题的关键并不在此,问题在于如Fig A2所示,在大于30000步后的某个瞬间,温度系数将会『骤降』到0,此时为Inf,因此loss scale也随着『骤降』到0了。如何理解这个温度系数的骤降呢?笔者个人是这样理解的,注意到混合精度训练机制下,是对梯度进行检测是否在有效FP16范围内,从而对loss进行放缩的,而其中间的计算值(如激活值)是不会被检测的。同时,如Fig A3所示,在混合精度训练过程中,中间计算结果(如激活值),反向的梯度等是用FP16表示的,也就是说一旦温度系数太小,导致超过的时候,就会让其溢出变成inf,由于loss scale是通过检测梯度变化才会改变的,此时loss scale对其是无能为力的。这里出现的inf将会导致梯度出现inf,使得loss scale缩小decr_ratio(float)倍,然而即便如此,在训练的过程中仍然会不断持续出现inf(主因是FP16的表示范围太小了,很容易就会超出65000的上限),因此loss scale一直缩小decr_ratio(float)倍,极其快速地达到了0(下溢出了)。注意到一旦loss scale变成0,就再也无法脱离0这个命运了,因为无论0乘上何数都为0。不过笔者暂时对为何此处温度系数也会突然骤降到0还不是很理解,上面的解释暂时没法解释这个现象。

正常的FP32的训练过程如Fig A4所示,FP32的表示范围比FP16大太多了(上限高了30多个数量级),其上溢出的风险远远小于FP16,因此也不存在loss scale这种妥协机制。在FP32下训练,其温度系数会逐渐收敛到一个值(比如是0.05左右),而不会一直下降到0,同样此时loss和metric都不会崩溃,能够正常训练。

Fig A2. 温度系数和loss scale在某个瞬间会骤降为0,导致了这场灾难。

Fig A3. 在AMP 中, 模型参数 weight , 前向中间的结果activation, 反向的gradient 都以FP16 形式存储, 由此可以减少模型占用的显存空间,同时提高计算和通信速度,也就是使得训练吞吐更大,训练更快. Paddle框架会为每一个weight 维护一个FP32副本, 用于参数更新。

Fig A4. FP32情况下的训练情况,分别是loss,metric和温度系数的曲线图,由于FP32下无loss scale这个概念,因此不需要绘制loss scale图(理论上是个常数)。

Reference

[1]. https://blog.csdn.net/LoseInVain/article/details/119515146, 《MoCo 动量对比学习——一种维护超大负样本训练的框架》

[2]. https://blog.csdn.net/LoseInVain/article/details/119516894, 《CLIP-对比图文多模态预训练的读后感》

[3]. https://fleet-x.readthedocs.io/en/stable/paddle_fleet_rst/fleet_collective_training_speedup_with_amp_cn.html, 《自动混合精度练加速分布式训练》