SigLIP——采用sigmoid损失的图文预训练方式

CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练...

FesianXu 20240825 at Wechat Search Team

前言

CLIP中的infoNCE损失是一种对比性损失,在SigLIP这个工作中,作者提出采用非对比性的sigmoid损失,能够更高效地进行图文预训练,本文进行介绍。如有谬误请见谅并联系指出,本文遵守CC 4.0 BY-SA版权协议,转载请联系作者并注明出处,谢谢

  • 关键字: 高效图文预训练、sigmoid损失取代softmax损失
  • paper发表信息:CVPR 2023

联系方式:

e-mail: FesianXu@gmail.com

github: https://github.com/FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号:机器学习杂货铺3号店


基于对比学习的图文预训练方式,自从CLIP [1] 横空出世后,就成为了图文预训练的主流方式,引申出了一系列的工作,如ALIGN [3]、FLIP [4]、LiT [5]等。这些工作在数据使用、训练效率等上进行了探索,但是其核心损失还是采用了infoNCE,也即是对比型的损失。在SigLIP [2] 中,作者提出了基于sigmoid损失的图文预训练方式,并且指出采用sigmoid损失能带来更高效的图文预训练效率和效果。在此之前,我们有必要再复习下CLIP的基本思想。CLIP是一个双塔结构,分别有图片塔和文本塔,那么损失可以表达为式子(1),其中的,是图片特征和文本特征的L2 normalization后的结果,为温度系数,其中为可学习参数,为一个批次(batch)的数据。 其基本思想就是从打分矩阵中,从i->tt->i的方向去判断出正样本的位置(也就是对角线的位置),注意到由于采用的是softmax形式去归一化正负样本,将其视为了概率分布,因此正负样本之间的概率关系是耦合在一起的,在提高正样本概率的同时,势必会压低负样本的概率。

fig-clip-frame

Fig 1. CLIP的基本结构由图片塔和文本塔组成,打分矩阵的对角线为正样本,从i2t和t2i的方向分别计算infoNCE损失。

而在SigLIP中,损失函数为式子(2)所示,其中的为给定图片文本对的标签,当为成对的正样本时候,当不是成对的负样本时候。此时对于正负样本来说是解耦的,增加正样本的概率并不意味着压低负样本的概率。负样本数量的绝对占优,会导致在训练初期负样本的损失主导了整个损失,因此引入了一个可学习的偏置项去缓解初始阶段的训练困难问题,此处的在原文中被初始化为-10,这也容易理解,初始的logit减去一个较大的值(如-10),使得正负样本logit的差别相对不会很大,在正负样本数量极大不均匀的情况下,可以让初始状态更加均匀,从而不会带来过度调整。

整个损失的建模,如以下代码所示:

1
2
3
4
5
6
7
8
9
10
# img_emb : image model embedding [n, dim]
# txt_emb : text model embedding [n, dim]
# t_prime, b : learnable temperature and bias
# n : mini-batch size
t = exp(t_prime)
zimg = l2_normalize(img_emb)
ztxt = l2_normalize(txt_emb)
logits = dot(zimg, ztxt.T) * t + b
labels = 2 * eye(n) - ones(n) # -1 with diagonal 1
l = -sum(log_sigmoid(labels * logits)) / n

这个就是SigLIP的核心优化点,我们先不考虑这个建模的模型效果,先看到这种建模方式带来的模型训练的优势。

  • 在CLIP中,正负样本是pairwise建模的:由于采用的softmax函数去建模正负样本之间的关系,而CLIP训练的global batch size一般都很大(如32k),这意味着GPU #1上的正样本需要见到其他所有GPU上的样本,并以此作为负样本。因此通常都需要汇聚多节点多卡的特征向量,这个时候需要频繁地调用all_gather,带来了沉重的通信代价,也会拖慢整个训练过程的速度。
  • 在SigLIP中,正负样本是pointwise建模的:采用的sigmoid loss是独立对每个正负样本进行计算的,最后再进行loss的累加,这意味着可以在本地完成大部分的计算,在涉及到本地的正样本和其他设备的负样本进行交互计算的时候,仅需要很少的gather操作就能完成设备间向量的交换就可以(对于图文预训练来说,交换文本特征向量即可,通信代价很低),而不需要all_gather操作。

我们着重介绍下SigLIP是如何进行分布式训练的,假设全局的batch size为,一共有个GPU,那么每个GPU上的batch size为,可以将公式(2)的损失拆解为公式(3)所示,在Fig 2. 展示了整个过程的示意图,在初始化阶段,我们以第一个GPU为例子,其所包含的样本为: 此时GPU 1可以完成一次公式(3)中的C计算,然后,交换GPU 1和GPU 2的文本编码器特征向量,既是: 此时GPU 2完成一次公式(3)中的B计算,以此类推,直到GPU 1遍历完所有样本为止,其他GPU也是如此操作的,最终把所有卡上的损失汇聚即可,也就是A计算。这个轮流交换不同GPU之间数据的操作,可以称之为permutation。不难发现,整个过程的通信成本来自于permutation,一共需要gather操作即可完成一次permutation,而在CLIP中需要对图文的编码器特征都进行汇聚,因此需要2次all-gather操作。如果all-gather采用ring的话,那么一个all-gather就是gather操作。由此我们得出一个SigLIP和CLIP的性能复杂度对比:

模型通信复杂度单卡储存复杂度单卡计算复杂度(计算的数量)单卡计算复杂度(计算的数量)
CLIP
SigLIP

容易发现,SigLIP无论从通信复杂度,储存复杂度还是计算复杂度上,都远比CLIP更为优越。

fig-siglip-efficient-implement

Fig 2. SigLIP高效的损失计算示意图,假设有3个设备,每个设备上的batch size为4,global batch size为12。

让我们再关注到SigLIP的模型能力表现,作者主要对比的是SigLIP,以及将图片表征固定的SigLiT(从而可以将batch size设置到非常大,比如100万)以及CLIP的表现。我们都知道在CLIP中采用对比损失,意味着越大的batch size能极大地提高对比效率,从而提升效果,受限于softmax的内存占用情况和GPU卡数等原因,无法将batch size设置得很大,在SigLiT中则可以将batch size设置到百万以上,从而探索极大batch size情况下的收益。如Fig 3.所示,作者对比了三种模型在batch size进行尺度放大后的0-shot能力,训练量都是18B的数据量,容易发现几点结论:

  1. 在batch size小于32k的时候,采用sigmoid的SigLIP的性能都会优于采用softmax的CLIP。
  2. 在batch size足够大的时候,CLIP能够追上,甚至超越SigLIP的表现,但是最佳性能仍然是SigLIP@32k情况下得到,从实际应用来看,采用SigLIP能用更少的资源更快的训练速度得到更好的性能。
  3. 从SigLiT的实验来看,随着batch size的尺度放大性能将会在32k batch size的情况下达到饱和,同时能观察到SigLiT在不同batch size下持续优于LiT。继续提高batch size将不能带来更好的收益,甚至会有细微的性能下降。

fig-sig-clip-lit-result

Fig 3. SigLiT、SigLIP和CLIP在batch size进行尺度放大情况下的0-shot性能对比。

超大的batch size是否需要更长的训练量支持?作者在SigLiT的设置下,训练了更长时间(也即是见了更多数据量),如Fig 4.所示,在超大batch size,如262k的情况下,对比较小batch size(如8k)提供更长的训练时间的确能观察到性能的较大提升。并且也能观察到在超大batch size下,采用sigmoid和采用softmax的区别很小,但是在较小batch size(如8k)的情况下,则差距明显。因此,在资源受限的情况下,采用SigLIP是很划算的,所需资源少而且性能更强。同时,这个试验也说明了,超大的batch size并不意味着训练得更快,反而还需要更长的训练时间。

fig-longer-training

Fig 4. 扩大了见过的数据量后,越大的batch size能带来较为明显的性能提升,同时,可以观察到在超大batch size下,采用sigmoid和采用softmax的区别很小,但是在较小batch size(如8k)的情况下,则差距明显。

除了batch size的影响外,作者还探索了很多有趣的点,包括SigLIP在多语言数据集上的表现、大尺度batch size下的训练稳定性问题、训练中负样本比例的影响、sigmoid训练的鲁棒性探索等问题。在多语言数据集上,作者发现性能同样在32k batch size上达到了饱和,其他细节就不累述了,感兴趣的读者自行翻阅。

笔者比较感兴趣的是其他问题,比如在大尺度batch size下,训练容易出现不稳定的情况,这个原因在于在训练过程中,gradient norm会出现大幅度的尖峰,如Fig 5. 所示,这导致了参数和训练损失的不稳定(也即是尖峰),作者观察到,如果将Adam优化器的值从0.999下降到0.95,那么训练过程就会稳定下来。

fig-training-stability

Fig 5. 大尺度batch size下训练容易出现不稳定的情况。

从公式(2)中,注意到SigLIP是对所有正负样本的pair进行计算损失然后累加求和的,这意味着可以从中剔除掉负样本以控制负样本的比例。对于batch size为的一次损失计算而言,其中有个正样本,有个负样本,负样本其实在后期很多都是简单负样本,是否可以剔除简单负样本是一个值得探究的问题。作者提出了几种消融试验,去挑选负样本,从而控制正负样本比例:

  • 随机:随机挑选负样本对,并且对其进行剔除。
  • 难负样本:把难负样本保留下来,即是通过将最高打分的topk负样本保留下来。
  • 简单负样本:把简单负样本保留下来,即是将打分最低的lowk负样本保留下来。
  • 难负样本+对齐训练量:保留难负样本的同时,提高训练step数量以对齐训练数据量。

实验结果如Fig 6.所示,其中的横坐标为一个batch内的正样本数量:负样本数量,其中的1:16k则是不进行任何负样本剔除的基线,从实验中可以得出几个结论:

  • 只保留简单负样本,会使得模型性能完全崩溃。
  • 随机剔除负样本,也会损失模型性能。
  • 只保留难负样本,对模型性能的损失是最小的,在对齐了训练数据量后(因为剔除了负样本,同个step下模型讲过的数据对数量就少了,因此需要训练更多step去对齐训练数据量),性能甚至还能比基线更好,这说明了难负样本才是最有价值的,怎么去合理地挑选难负样本是提高模型性能的关键因素
  • 再看到随着负样本数量的减少,可学习偏置的值和正负样本的平均logit值都在递增,这也是符合预期的。有趣的一点是,当采用难负例保留的策略中,随着负样本数量逐渐减少,正负例的logit区分度在加速减少,并且正例的logit变化基本上是平坦的,这个现象和随机丢弃的策略是不同的。对此的解释是,本来难负样本和正样本就比较接近,在减少了负样本数量,只保留最难的负样本后,负样本的logit值就加速地上涨,从而导致了区分度减低的情况,这也是符合预期的。

fig-effect-batch-composition

Fig 6. 采用不同策略控制损失中的正负样本比例的效果对比。

前文已经提到了sigmoid和softmax的区别在于,前者解耦了正负样本的概率关系,这使得即便负样本中出现假阴性样本,也只会影响自己的损失,而不会影响到其他样本,因此这带来了数据的健壮性。作者也进行了对应的试验,如Fig 7.所示,作者对数据中的图片、文本进行随机加噪、对batch内的图文对进行随机打乱、或者将上面的加噪方式都进行组合,发现基于sigmoid的训练过程,总是比基于softmax的训练过程更加鲁棒。

fig-corruption-robust

Fig 7. 基于sigmoid的训练能够提高训练的健壮性,对数据中的噪声更为鲁棒。

总的来说,SigLIP是一个很棒的工作,作者采用sigmoid损失去取代对比学习中的softmax函数,以更小的资源开销带来了更好的模型表现,目前也被很多多模态大模型所采用,作为视觉端的编码器。

Reference

[1]. Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., ... & Sutskever, I. (2021, July). Learning transferable visual models from natural language supervision. In International Conference on Machine Learning (pp. 8748-8763). PMLR. aka CLIP

[2]. Zhai, Xiaohua, Basil Mustafa, Alexander Kolesnikov, and Lucas Beyer. "Sigmoid loss for language image pre-training." In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 11975-11986. 2023. aka SigLIP

[3]. Jia, C., Yang, Y., Xia, Y., Chen, Y. T., Parekh, Z., Pham, H., ... & Duerig, T. (2021, July). Scaling up visual and vision-language representation learning with noisy text supervision. In International Conference on Machine Learning (pp. 4904-4916). PMLR. Short for ALIGN

[4]. Li, Y., Fan, H., Hu, R., Feichtenhofer, C., & He, K. (2022). Scaling Language-Image Pre-training via Masking. arXiv preprint arXiv:2212.00794. aka FLIP

[5]. Zhai, X., Wang, X., Mustafa, B., Steiner, A., Keysers, D., Kolesnikov, A., & Beyer, L. (2022). Lit: Zero-shot transfer with locked-image text tuning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 18123-18133). aka LiT