SwAV
前言
与前面的一些工作不同,SwAV不再进行个体判别任务,而是提出了新的任务 ————聚类
并在训练的模型结构上也做了相应改动,而非只调整训练方法。
这里因为涉及到了聚类,具体数学推导难度较大,有兴趣可以跟着我后面贴的视频走一遍
提出的背景
个体判别任务的局限性
1. 进行个体判别任务时,我们有可能选到与正样本属于一个类的特征作为负样本,比如两张不同狗的图片等等,这会影响学习的效率。
2. 个体判别任务把类别分的过细,不能很好地抽取特征,过多的类也增大了计算压力和字典储存压力。作者认为这过于原始和暴力。
模型结构
下图左边是常规的对比学习(比如 SimCLR)的结构,右图是 SWAV 的结构,不难看出多了一个叫 prototypes 的东西。这个东西其实是聚类中心向量所构成的矩阵。
下面的内容可能有些理解上的难度(反正我第一次听讲解的时候就云里雾里的),我会尽可能直白地描述这个过程。
聚类中心?
首先我们有个新的东西prototypes,它是聚类中心的集合,也就是许多作为聚类中心的向量构成的矩阵。
这些聚类中心是我设定在超球面上的,离散的一些点,我希望让不同的特征向它们靠拢以进行区分(也就是所谓聚类)。
更直白地讲,我在地上撒了一把面包屑,地上本来散乱的蚂蚁会向面包屑聚集,形成一个个小团体。蚂蚁就是不同图像的特征,面包屑就是我设定的聚类中心
聚类中心我知道了,然后呢?
先说我拿他干了什么,再一步步讲为什么要这么做吧。
首先我们手里有抽取出来的特征z1,z2,以及一个我随机初始化的聚类中心矩阵 c。我分别求这个矩阵和z1,z2的内积,并进行一些变换得到 Q1,Q2。当 z1,z2 都是正样本时,我希望Q1 与 z2 相近,Q2 与 z1 相近。如果有一个是负样本则尽可能远离。也就是拿 Q 当 ground-truth 做训练。最后这步前面已经讲过 NCEloss 等损失函数了,用它们就可以达成这个任务。
而我们的优化要采用 K-means (不懂可以看这里) 的类似做法,先对聚类中心进行优化,再对特征进行优化。
so,why?相信你现在肯定是一脸懵,不过别急,希望我能为你讲懂。
首先是第一步,为什么要求内积?
如果你有好好了解线性代数的几何性质,应当了解两个向量的内积就是一个向量在另一个向量上的投影,而一个向量与一个矩阵的内积,就是把这个向量投影到这个矩阵代表的基空间中。
我做的第一步就是把抽出来的特征 z 用聚类中心的向量表示,这样更加方便对比聚类成功与否。
然后是第二步,我说的变换是什么呢?
我们现在求内积是为了把特征投影到聚类中心空间,为了避免模型训练坍塌(就是网络把特征全部聚到同一个点,开摆~)我要保证每个聚类中心被"使用"的次数,所以我们请出了Sinkhorn-Knopp 算法。这个算法比较硬核,我在这里不展开了,大家知道它是干啥的就行,具体的推导可以看我后面贴的视频,那里面有讲。
第三步应该不用怎么讲了吧?
就是普通的对比学习,也没啥特殊的了。正样本是自身通过数据增强和上面两步处理得到的特征,负样本则是同一 batch 中的其他特征。
数据增强的小 trick:multi-crop
其实就是一个工程经验,一般我们数据增强是取 2 个 224*224 的块,这里换成了面积基本不变的 2 大 2 小的 4 个块,事实证明效果不错。想了解这个的话可以看看原论文的实验
总结
主要贡献是上面我说的三步聚类算法以及后面的小 trick,Sinkhorn-Knopp 算法难度较高,大家有兴趣的话自行观看后面这个视频理解哈~
相关资料
【[论文简析] SwAV: Swapping Assignments between multiple Views [2006.09882]】