논문 리뷰/semi-supervised learning

[논문 리뷰] ReMixMatch

curious_cat 2023. 4. 6. 01:27
728x90
728x90

개요

논문 링크: https://arxiv.org/abs/1911.09785 (REMIXMATCH: SEMI-SUPERVISED LEARNING WITH DISTRIBUTION ALIGNMENT AND AUGMENTATION ANCHORING)

이전 글:

Semi-supervised learning 알고리즘 중 하나다. 이전 논문인 MixMatch에 distribution alignment와 augmentation anchoring이라는 두 가지 기법을 추가해서 성능을 높인 논문이다.

  • Distribution alignment: unlabeled dataset의 pseudo-label distribution을 labeled dataset의 label distribution과 비슷하도록 유도하는 방법 (저자들은 mutual information을 사용해서 distribution alignment 하는 이유를 설명하지만, 생각해 보면 class balancing이랑 목표가 비슷하다).
  • Augmentation anchoring: 약한 augmentation으로 pseudo-label을 만들고 강한 augmentation을 준 이미지에 대해서 pseudo-label을 예측하게 하는 기법. 이것을 잘 하기 위해서 CTAugment라는 기법을 도입한다.

방법

Distribution alignment

  • Motivation: mutual information은 random variable x에 대해서 알면 random variable y에 대한 불확실성을 얼마나 줄일 수 있는지 나타내는 수치이다. unlabeled data를 잘 활용하는 방법 중 하나는 unlabeled image x에 대한 분포를 사용해서 label y에 대한 불확실성을 최대한 줄이는 것이다. 즉, mutual information \( I(y;x)\)를 최대화하면 된다는 뜻이다 (식 1은 정의; 식 2의 유도 과정은 논문 부록 참고):
    \[ \begin{align} \mathcal{I}(y;x) &= \int p(y,x) \log \frac{p(y,x)}{p(y)p(x)} dy dx \quad (1) \\ ~ &= \mathcal{H} (E_{x}[p_{model}(y|x;\theta)] - E_{x}[\mathcal{H}(p_{model}(y|x;\theta)]) \end{align} \quad (2)\]

    • \( p_{model}\): neural net, \( \theta\)는 모델 파라미터
    • \( E_{x}\)는 x에 대한 평균
    • 첫번째 항 해석: 데이터 전체에 대해서, 각 class로 라벨 되는 이미지의 수가 동일한 것을 선호 (equipartition)
    • 두 번째 항 해석: 각 이미지에 대해서 한 class만 1, 나머지 0을 선호 (confident prediction)
  • Distribution alignment 디테일: MixMatch에서 sharpening을 통해서 식 (2)의 두 번째 항과 비슷한 효과를 얻는다. 하지만 첫 번째 항에 대응되는 기법은 사용하지 않았다. 첫 번째 항처럼 unlabeled data에 대해서 equipartitioning을 하는 것은 유용하지 않을 수 있지만 다음과 같이 비슷한 아이디어를 유용하게 사용할 수 있다:
    • \( \tilde{p}(y)\): 학습 도중 모델이 unlabeled data에 대해서 class y를 예측하는 확률의 moving average라고 하자 (최신 128 batch에 대한 평균)
    • \( p(y)\)는 학습할 때 사용하는 class label 분포
    • \( q = p_{model}(y|u;\theta)\): unlabeled data u에 대해서 모델이 예측하는 확률
    • 다음과 같이 class에 대한 확률 분포를 바꿔준다: \( \tilde{q} \propto q \times \frac{p(y)}{\tilde{p}(y)}\). 이 분포가 확률 분포가 되려면 normalize해야 해서 다음과 같은 식을 얻을 수 있다:
      \[ \tilde{q} = \textrm{Normalize}(q \times p(y) / \tilde{p}(y))\]
    • \( q\) 대신 \( \tilde{q}\)를 사용해서 sharpening같은 후처리를 하고 pseudo-label로 사용
  • Distribution alignment 해석: \( \tilde{q}\)를 unlabeled data에 대해서 평균을 내면 \( p(y)\)가 된다. 즉, unlabeled data에 대한 pseudo-label을 labeled data의 distribution과 맞도록 balancing/alignment을 해주겠다는 뜻이다.

Augmentation Anchoring & CTAugment

  • Augmentation anchoring은 이해하기 쉽다: 이미지에 약한 augmentation을 줘서 (anchor 이미지) 모델에 통과시켰을 때 나오는 결과를 pseudo-label로 사용하고, 강한 augmentation을 준 K 개의 이미지들에 대해서 모델이 pseudo-label을 예측하게 하도록 학습하는 기법이다. 강한 augmentation을 잘하기 위해서 CTAugment를 사용한다.
  • CTAugment (control theory augment): data augmentation 하는 강도를 학습하면서 조절하는 방법. (data augmentation policy를 설정하는 방법들은 이전에도 제시가 되었지만, 본 논문 학습 세팅에서는 적합하지 않았다고 한다; 관심 있으면 논문 참고)
    • 우선 사용할 augmentation이랑 augmentation의 최소, 최대 강도를 정한다. 그리고 augmentation transformation의 강도를 조절하는 parameter을 N개의 bin으로 쪼갠다. 예를 들어 rotation augmentation을 할 때 회전 각도가 최소 -45도, 최대 45도이며 N=10이면, 각 bin은 -45, -35, -25, ..., 35, 45가 될 것이다. 실제로는 N=17 사용.
    • N개의 bin에 각각 weight \( m_i \)를 할당한다: \(m = (m_1,...m_N)\). 학습을 시작할 때 \( m_i=1\)로 초기화한다.
    • 각 학습 step마다 두 개의 augmentation transformation을 샘플링한다 (uniform sampling).
    • 그리고 \( \hat{m}\)를 다음과 같이 계산한다: \( \hat{m}_i = m_i\) if \( m_i>0.8\), \( \hat{m}_i=0\) otherwise. 이렇게 만든 \( \hat{m}\)을 Normalize해서 확률 분포로 만든 후 강도 bin i를 샘플링을 한다. 
    • weight를 업데이트는 다음과 같이 한다: 각 transformation마다 bin을 랜덤하게 (uniform distribution을 사용해서) 샘플링한다. 이렇게 얻은 강도를 사용해서 labeled image x에 augmentation을 해서 \( \hat{x}\)를 얻는다. 모델이 \( \hat{x}\)에 대한 예측이 정확하면 weight \( m_i\) 를 크게 해 주고, 부정확하게 예측하면 weight를 작게 한다.
      • 구체적으로 \( \omega = 1-\frac{1}{2L}\sum |p_{model}(y|\hat{x};\theta)-p|\). 여기서 p는 GT label, L은 number of classes. 
      • \( m_i = \rho m_i + (1-\rho) \omega\), \( \rho = 0.99\)

ReMixMatch

다음과 같이 \( \mathcal{X}', \mathcal{U}',\hat{\mathcal{U}}_1\)를 얻는다:

  • \( \mathcal{X}', \mathcal{U}'\)는 labeled/unlabeled data에 대해서 MixUp을 한 결과이며, label / pseudo-label을 사용해서 cross-entropy loss를 준다 (MixUp에서 하는 것처럼)
  • \( \hat{\mathcal{U}}_1\) (강하게 augment한 샘플; MixUp 사용 x)에 weak augmentation으로 얻은 guessed label로 cross entropy loss를 준다
  • 이미지를 rotate 하고 (r=0, 90, 180, 270도 중 하나) 모델에게 이미지 rotation 값을 예측하게 한다
  • 정리하자면 다음 loss function을 사용

추가 디테일

  • \( \lambda_r = \lambda_{\hat{\mathcal{U}}_1}=0.5 \)
  • T=0.5, Beta=0.75, \( \lambda_{\mathcal{U}} = 1.5\)
  • K=8
  • Adam, learning rate = 0.002, weight decay = 0.02
  • 학습 모델 parameter의 exponential moving average를 취해서 최종 모델을 얻음 (EMA decay rate = 0.999)

실험 결과

728x90
728x90