논문 리뷰/semi-supervised learning

[논문 리뷰] ShrinkMatch

curious_cat 2023. 8. 24. 22:56
728x90
728x90

개요

논문 링크: https://arxiv.org/pdf/2308.06777.pdf (Shrinking Class Space for Enhanced Certainty in Semi-Supervised Learning)

깃헙: https://github.com/LiheYoung/ShrinkMatch

이전 글:

요약:

  • (classification 관련) semi-supervised learning 에서 SOTA를 찍은 최신 논문이다.
  • FixMatch같은 방식들은 unlabeled data를 활용할 때 pseudo-label의 confidence가 낮으면 버리는 방식으로 학습을 하는데, 이렇게 하면 unlabeled data를 충분히 활용할 수가 없다 (버려지는 데이터가 있기 때문에). 
  • 이것을 해결하기 위해서 이 논문에서는 이렇게 버려지는 데이터를 간단하게 활용하는 방법을 제안한다.
  • 우선 confidence가 가장 높은 class (=top-1 class)에  unlabeled data의 class 정보가 잘 담겨있는 경향이 있다고 한다.
  • (밑에 그림 참고) 예를 들어서 unlabeled data가 tabby cat 사진인 경우 모델이 tiger cat, siamese cat같은 고양이들과 헷갈릴 가능성이 높지만, table, bed같은 물체들과는 확실하게 구분할 가능성이 높다.
  • 이러한 경우 고양이 종류가 헷갈려서 confidence가 낮을 수 있지만, tiger cat, siamese cat같이 헷갈릴 만한 class (confusion class)들을 제거한 class들만 가지고 보면 confidence가 높다.
  • 이렇게 축소된 (shrink된) class들을 가지고 pseudo label을 만들어서 사용하겠다는 것이 이 논문의 핵심 아이디어다

방법

FixMatch와 비슷한 부분

notation

  • \( D^l = \{ (x_k,y_k)\}\): labeled data (x: image, y: label)
  • \( D^u = \{ u_k\}\): unlabeled data
  • \( f \): classification 모델
  • \( \mathcal{A}^w\): weak augmentation
  • \( \mathcal{A}^s\): strong augmentation
  • \(u^w \): unlabeled data에 weak augmentation한 결과
  • \( u^s\): unlabeled data에 strong augmentation한 결과
  • \( p^w = f(u^w)\)
  • \( p^s = f(u^s)\)

Labeled data & confidence가 높은 unlabeled data는 FixMatch와 비슷하게 학습한다

  • consistency regularization: 약하게 augment한 데이터에서 얻은 \( p^w\)를 강하게 augment한 데이터에서 얻은 \( p^s\)의 GT 값으로 사용하겠다는 아이디어. 
  • FixMatch같은 경우 consistency regularization을 활용하는 방법 중  하나인데, 핵심 아이디어는 confidence가 threshold보다 높은 pseudo-label만 활용하겠다는 것이다.
  • unlabeled data에 대한 loss:

  • \( B^u\): unlabeled data의 batch size
  • \(\tau\): threshold
  • \( \xi(p^w_k)  = max(\sigma(p_k^w))\). 여기서 \( \sigma\) = softmax. 쉽게 말하면 top-1 class의 확률 값 = confidence 값.
  • \( H\) = cross entropy 
  • labeled data에 대한 loss: cross entropy
  • Full loss:

  • \( lambda_u\): hyperparameter

본 논문의 핵심 기여

shrunk class space는 어떻게 구하는가?

  • 이 논문의 핵심 아이디어는헷갈릴 만한 class들을 제거해서 축소된 (shrunk) class들만 봤을 때 top-1 class의 confidence는 높기 때문에, 이 shrunk space에서 pseudo-label을 활용하겠다는 것이다. 
  • 그러면 어떻게 헷갈리는 class들을 제거해서 shrunk space를 만드는 것이 좋을까?
  • 우선 weak augmentation된 이미지의 class 예측 확률 \( p^w\)를 구성하는 값들을 내림차순으로 나열한다: \( p^w = \{ s^w_{n_i}\}^C_{i=1}\)
    • \( s^w_{n_i}) 는 \(n_i \)번째 class에 해당하는 확률 값
    • \(i = 1, ..., C\)이고 \( n_i\)는 내림차순으로 정렬할 때 사용되는 index라고 생각하면 된다
    • 따라서 \( s^w_{n_i} \ge s^w_{n_{i+1}}\)
  • shrunk space는 다음과 같이 만든다
    • top-1 class는 포함시킨다 
    • confusion class \( \{ n_i\}_{i=2}^{K-1}\)와 확실히 아닌 class \( \{ n_i\}_{i=K}^C\)는 식 (3) & (4)를 통해서 통해서 구한다 (i=2,...,K-1가 헷갈리는 class, i=K,...,C까지가 확실히 아닌 class)
    • 식의 의미는, confusion class를 제외한 class 사이에서 top-1 class의 confidence (\( \xi\) = top-1 class의 confidence를 나타냄)가 특정 threshold보다 높아지는 최소한 K를 구하겠다는  것이다.

shrunk class space에서 어떻게 학습하는가?

  • 이전처럼 strong augmentation을 한 이미지에 대해서 weak augmentation을 GT처럼 사용한다 (consistency regularization)
  • weak augmentation을 한 이미지를 기반으로 찾은 shrunk class에서 확률을 다음과 같이 적자: \( \hat{p}^w = \{s^w_{n_1} \} \cup \{s^w_{n_i}\}^C_K \)
  • 비슷하게 strong augmentation을 해서 얻은 확률들도 다음과 같이 적자: \( \hat{p}^w = \{s^s_{n_1} \} \cup \{s^s_{n_i}\}^C_K \)
  • 식 (1)과 비슷하게 다음과 같은 loss를 사용하면 된다

\

  • 하지만 식 (2)에 있는 loss를 계산할 때 사용한 classification head (\(\h_{main}))를 기반으로 얻은 확률에 식 (5)를 적용해서 학습하면 모델 confidence가 비정상적으로 높아질 수 있다고 합니다: 확실하지 않은 unlabeled data에 대한 confidence가 비정상적으로 높아질 수 있다고 합니다.
  • 이것을 방지하기 위해서 보조적인 classification head \( h_{aux}\)를 도입해서 학습은 \( h_{aux}\)를 통해서 하면 된다고 합니다 (shrunk class, \( \hat{p}^w \)는 main classification head로 구하고, \( \hat{p}^s \)는 auxiliary classification head로 구해서 식 (5)에 있는 loss 학습)

reweighting scheme

  • 1. [0.8,0.1,0.1]같은 확률 분포와 [0.5, 0.3, 0.2]같은 확률 분포에서 둘 다 동일하게 1번째 class가 맞다고 확정 짓는 것은 말이 안 된다 (전자의 top-1 class에 대한 confidence가 후자보다 더 믿음직스럽다는 뜻) 2. 학습 초반에는 대부분의 unlabeled class에 대한 confidence가 높지 않고, confidence 값이 믿음직스럽지 않다. 
  • 1번을 해결하기 위해서 식 (5)를 top-1 class의 confidence로 weight를 준다 (shrunk class space가 아닌 original class space에서 확률!)

  • 2번을 해결하기 위해서 전반적으로 모델의 top-1 class의 예측 confidence  \( m^g\)이 얼마나 높은지 추적해서 \( m^g\)로 weight를 준다:
    \[ m = \frac{1}{B_u} \sum_{k=1}^{B_u} \mathbb{1} (\xi(p^w_k)\ge \tau) \]
    \[ m^g \leftarrow \gamma m^g + (1-\gamma) m\]

  • 식 6 & 8을 함께 고려해서 다음과 같은 loss를 식 5 대신에 사용한다:

최종 loss function

  • \( \mathcal{L}_x\): supervised loss, 나머지는 식 (1) & 식 (10) 참고.

결과

다양한 task에서 SOTA를 찍었다.

hyper parameter 등등은 논문 참고.

728x90
728x90