개요
제목: Unsupervised Learning of Visual Features by Contrasting Cluster Assignments
논문 링크: https://arxiv.org/abs/2006.09882
참고하면 좋은 이전 글:
Semi-supervised learning 기법 중 하나. 대부분의 contrastive learning 기법들은 많은 resource가 있어야 학습 가능하다. 보통 large batch size를 사용하거나 (SimCLR) momentum network (MoCo같은 기법)이 필요하다. 이 논문에서는 memory efficient한 방법으로 데이터를 clustering하는 기법을 제시한다. 방법은 간단하고 효율적아며, 당시 sota를 찍기도 했다. 아직까지는 여전히 유용한 기법인 것 같다.
방법
이 논문에서 두 가지의 좋은 아이디어를 제시한다:
- swapped prediction이라는 학습 task
- multicrop이라는 augmentation 방식
자세한 설명은 뒤에 하겠다.
Overview
- 목표는 좋은 visual representation을 supervision 없이 online fashion으로 얻는 것.
- SeLa, DeepCluster같은 방법들은 label을 얻는 step과 (offline) neural net을 학습하는 스텝으로 (online)구분되어 있다
- Motivation은 InstDisc 같은 contrastive learning 방식에서 얻었지만 다음과 같이 다르다:
- Contrastive learning은 이미지를 다른 방식으로 augmentation하고 neural net에 통과시킨 feature을 비교한다 (살짝 스킵되었지만 negative sample도 있다)
- SwAV에서는 이미지를 다른 방식으로 augmentation하고 neural net에 통과시킨 feature을 바로 비교하지 않는다. 이렇게 얻은 feature을 prototype c라는 것들을 사용해서 code Q로 바꿔주고, feature z1를 사용해서 code Q2를 예측하고, feature z2를 사용해서 code Q1을 예측하는 방식으로 학습한다 (swapped prediction):
\[ L(z_t,z_s) = \ell(z_t,q_s) + \ell(z_s,q_t) \quad (1)\]
더 자세한 내용은 나중에 설명
Notation
- \( x_n\): image
- \( t \ \in \mathcal{T}\): augmentation할 때 쓰는 sampled transformation
- \( x_{nt}\): transformed image
- \( f_\theta\): neural net (ResNet backbone + MLP projection head 라고 생각하면 된다)
- \( z_{nt} = \frac{f_\theta(x_{nt})}{||f_\theta(x_{nt})||_2}\): normalized feature
- \( q_{nt}\): \( z_{nt}\)를 K 개의 학습 가능한 prototype \( \{ c_1, ..., c_K\}\)로 mapping해서 얻은 code. 얻는 방법은 뒤에 설명.
- \( C\): \( \{ c_1, ..., c_K\}\)를 column으로 갖는 행렬
Swapped prediction problem
- Eq (1)가 swapped prediction problem이다
- Eq (1)에 등장하는 \( \ell\)는 code q와 z&c로 부터 얻은 확률 분포 사이의 cross entropy loss:
\[ \ell(z_t,q_s) = -\sum_k q_s^{(k)} \log p_t^{(k)}, \textrm{where} ~ p_t^{(k)} = \frac{\exp(\frac{1}{\tau} z_t^T c_k)}{\sum_{k'} \exp(\frac{1}{\tau} z_t^T c_{k'})} \quad (2)\]
- \( \tau\)는 temperature parameter
- 모든 이미지와 data pair에 대해서 더하면 다음과 같은 loss를 얻을 수 있다
- 보다시피 이 식에는 \( \sum_{k} q_{ns}=1\)이라는 가정이 녹아들어 있다.
- 이 loss를 최소화하기 위해서 prototypes C와 neural network의 \( f_\theta\)의 parameter \( \theta\)를 학습한다
- SeLa와 연관을 짓자면, SeLa에서는 \( \frac{1}{N}\sum_{N}\ell(z_t,q_t)\)을 최소화하는 것으로 생각하면 되고, SwAV에서는 \( \frac{1}{N}\sum_{N}(\ell(z_t,q_s ) + \ell(z_s,q_t))\)를 online으로 최소화하는 것으로 생각하면 된다 (augmentation하는 기법에서도 차이점이 있기는 하다).
Online code prediction
이제 q를 계산하는 방법을 소개하겠다
\[ \max_{Q\in \mathcal{Q}} Tr(Q^T C^T Z) + \epsilon H(Q) \quad (3) \]
- 처음 보면 잘 이해가 안 될 수도 있지만 SeLa 를 이해하고 있으면 사실 어렵지 않다.
- Notation:
- \( Z = [z_1, z_2, ..., z_B]\)는 B개의(batch size) feature들
- \( C = [c_1, c_2, ..., c_K]\): K개의 prototypes.
- \( Q = [q_1, q_2, ..., q_B]\), 각 feature에 해당되는 code
- 논문에는 설명이 생략되어 있지만 일단 BQ (B곱하기 Q) 는 batch에 있는 각 data가 C가 나타내는 K개의 class에 해당되는 label이라고 생각하면 된다 (예를 들어 \(q_i=(1/B,0,...,0)\) 이면 \( z_i\)는 \( c_1\)가 나타내는 class에 해당된다는 해석).
- 이제 첫 번째 항을 보자. max를 무시하면 \( \sum_{n \in \mathcal{B}}z^T_n C^T q_n\)과 같은데, Swapped prediction problem 섹션에서 소개한 (식 2 밑에 있는 ) loss의 첫 번째 항을 한 batch에 대해서 (\(\mathcal{B} \) = batch) 계산한 것과 '거의' 같다 (Swapped prediction problem 섹션에 있는 loss의 부호를 바꾸고, data 전체에 대해서 평균이 (1/N)아니라 batch에 대해서 (1/B)평균으로 해석하면 된다; 그리고 1/B는 Q로 흡수해야 한다; 그래서 BQ가 label에 해당된다).
- 이제 max를 고려하자. Swapped prediction problem 섹션에서 소개한 loss의세 번째 항은 이 max에 영향을 주지 않기 때문에 사실 포함해도 상관이 없다 (Q와 무관하기 때문에). 따라서 식 3의 첫 번째 항을 최대화하는 것은 batch에 대해서 식 (2)를 더한 값을 최소화시키는 q를 찾는 문제다. 다시 말해서 z (feature)들이 K개의 class에 해당하는 확률과 가장 잘 대응되는 label q를 찾는 문제다.
- 이때까지 q가 어떤 feature (다시 말해서 이미지)가 어떤 prototype c에 대응되는지 나타내는 label이라고 해석했다. 이렇게 하면 모든 feature들이 같은 class로 label 되도록 학습될 가능성이 있다. 이런 것을 피하기 위해서 SeLa에서는 feature들이 K개의 class에 골고루 assign 되도록 제약을 두는데, SwAV에서도 비슷하게 한다 (SeLa에서는 전체 dataset에 대해서 골고루 assign 되도록 하는데, SwAV에서는 batch에 대해서만). 다시 말해서 Q의 row에 대한 합이 각각 \(\frac{1}{B} \frac{B}{K}\) 가 되도록 제약을 준다 (\( q_i\)는 label/B이니까). 또한 q가 0,1 값만 갖도록 제약을 주지 않고, q의 component의 합이 1/B이 되도록 제약을 준다:
\[ \mathcal{Q} = \{ Q \in \mathbb{R}^{K\times B}_+ | Q1_B = \frac{1_K}{K}, Q^T 1_K = \frac{1_B}{B}\} \quad (4) \]
- 이런 제약을 줬을 때 식 (3)의 첫 번째 항은 optimal transport distance / earth mover's distance가 되는데 푸는 것이 까다롭기도 하고 discrete solution만 나오는 것으로 알려져 있다 (참고: 2023.02.18 - [paper review] - [논문 리뷰] Sinkhorn Distances: Lightspeed Computation of Optimal Transport) 하지만 \( \epsilon H(Q)=-\epsilon \sum_{ij}Q_{ij} \log Q_{ij} \)로 regularization을 주면 Sinkhorn-Knopp 알고리즘으로 빠르게 풀 수 있고, solution도 regularize된다.
SeLa에서는 식 (3)의 solution에서 label q들을 discrete하게 바꿔서 사용한다 (각 q에서 가장 확률이 높은 index를 선택해서 classification label로 사용). 하지만 SwAV에서는 이렇게 discretize하지 않고 Q를 그대로 사용한다 (더 좋은 성능을 얻었다고 한다).
따라서 SwAV의 아이디어는 다음과 같다: batch에 있는 이미지를 두 가지 방식으로 augmentation을 해서 \( z_s,z_t\)를 얻는다. 그리고 C를 사용해서 \( q_s,q_t\)를 식 (3)을 풀어서 얻는다. 그리고 \(z_s\) 를 사용해서 \( q_t\)를 predict하고 \(z_t\) 를 사용해서 \( q_s\)를 predict하면서 C와 neural network를 같이 학습시킨다.
Batch size 관련
위에서 설명했듯이 batch 안에 있는 feature들을 K개의 prototype로 균등하게 label을 배분하고, 이 label을 학습하도록 알고리즘이 설계되어 있다. 따라서 batch size가 작으면 분배가 잘 안 될 수 있다. 이러한 경우 first in first out queue를 사용해서 해결한다. 예를 들어 batch size B = 4096을 사용하고 싶지만 현실적으로 B=256밖에 사용하지 못한다고 하자. 이런 경우 4096-256 = 3840 크기의 queue를 만든다. 이 queue에는 feature들을 저장하고, 각 step마다 256 개의 feature들을 계산하고, 새로 계산한 feature(gradient 계산)와 예전에 계산했던 feature(queue에 있는 feature; gradient 계산 x)를 사용해서 Q를 계산한다. 이후는 이전에 설명한 방식대로 학습을 하고 queue를 업데이트한다.
Multi-crop augmentation
이미지에서 random crop을 하고, 서로 비교를 하는 방식이 큰 도움이 되지만 crop하는 수가 증가하면 memory 사용량이 감당이 안될 수 있다 (quadratic increase). 이것을 해결하기 위해서 multi-crop을 제안하는데 다음과 같다. 우선 이미지에서 2개의 crop은 비교적 크게 자르고, V개의 작은 crop을 한다. 그리고 다음과 같은 loss를 계산한다:
\[ L(z_{t_1},z_{t_2},...,z_{t_{V+2}}) = \sum_{i \in \{1,2\}} \sum_{v=1}^{V+2} 1_{v\neq i} \ell(z_{t_v},q_{t_i}) \quad (6) \]
여기서 \( t_1,t_2\)는 두 개의 큰 사이즈 crop에 해당하고 나머지가 작은 사이즈 crop이다. 보다시피 큰 사이즈 crop만으로 q를 계산한다.
참고로 큰 사이즈 crop이란 비교적 큰 영역을 (0.14~1 배 사이즈) crop하고 224로 resize한 이미지를 의미하고 작은 사이즈 crop은 (0.05~0.14 배 사이즈) 비교적 작은 영역을 crop해서 96로 resize한 이미지로 생각하면 된다.
참고로 실제로 crop한 후 horizontal flips, color distortion and Gaussian blur을 사용한다.
학습 방식
LARS optimizer 사용, cosine learning rate, MLP projection head 사용. 학습 디테일은 논문 참고
결과
SwAV로 ImageNet에 pre-training한 모델이 linear classification & object detection에 대해서 ImageNet 에 supervised training보다 더 좋은 결과를 얻었다:
스킵한 내용
SeLa-v2, DeepCluster-v2, SimCLR에 multi-crop을 적용하는 내용은 SwAV method와 직접적인 연관이 없어서 이번 리뷰에서는 스킵하겠다. 논문의 Appendix를 찾아보면 자세한 내용이 있다.
'논문 리뷰 > self-supervised learning' 카테고리의 다른 글
[논문 리뷰] Barlow Twins (0) | 2023.03.17 |
---|---|
[논문 리뷰] SimSiam (0) | 2023.03.13 |
[논문 리뷰] SeLa (0) | 2023.03.04 |
[논문 리뷰] BYOL (0) | 2023.03.01 |
[논문 리뷰] DeepCluster (Deep Clustering for Unsupervised Learning of Visual Features) (1) | 2023.02.26 |