728x90
728x90
개요
논문 링크: https://arxiv.org/abs/1905.02249 (MixMatch: A Holistic Approach to Semi-Supervised Learning)
참고하면 좋은 논문: MixUp: https://arxiv.org/abs/1710.09412
Semi-supervised learning 기법 중 하나. unlabeled data에 label을 단순히 teacher model로 만들지 않고, MixUp을 사용해서 labeled data와 unlabeled data를 섞어버린다. labeled data와 unlabeled data의 '중간' 데이터들을 MixUp으로 만들어서 unlabeled data에 대한 label의 quality를 좋게 만들어 주겠다는 것이 핵심 아이디어.
방법
Unlabeled data에 label 얻는 방법
알고리즘은 다음과 같다
- \( \mathcal{X}\): batch of labeled examples, labels = one-hot vectors
- \( \mathcal{U}\): batch of unlabeled examples
- MixMatch 는 \( \mathcal{X}\), \( \mathcal{U}\) 에 MixUp을 적용해서 augmented labeled data \( \mathcal{X}'\), augmented unlabeled data with guessed label \( \mathcal{U}'\)를 얻는다.
\[ \mathcal{X}',\mathcal{U}' = \textrm{MixMatch}(\mathcal{X},\mathcal{U},T,K,\alpha) \quad (2)\]- T, K, \( \alpha\): hyperparameters, 밑에서 추가 설명
- 우선 labeled data \( x_b \in \mathcal{X}\) 를 augmentation 해서 \( \hat{x}_b\)를 얻고, unlabeled data \( u \in \mathcal{U}\) 를 K 번 augmentations 해서 \( \{ \hat{u}_{b,1}, ..., \hat{u}_{b,K}\} \)를 얻는다.
- 위에서 얻은 augmented unalbeled data를 모델에 통과시킨 후 얻은 label을 평균 내서 guessed label을 얻는다:
\[ \bar{q}_b = \frac{1}{K} \sum_{k=1}^{K} p_{model} (y|\hat{u}_{b,k};\theta) \quad (4) \] - 식 4에서 얻은 label은 confident하지 않을 수 있어서 sharpening해준다 (label의 entropy를 줄여주겠다는 뜻; [논문 리뷰] Pseudo label 와 비슷한 아이디어):
\[ Sharpen(p,T)_i = \frac{p_i^{1/T}}{\sum_{j=1}^L p_j^{1/T}} \quad (7)\]
여기서 p는 class label에 대한 분포; \( \bar{q}_b\)라고 생각해도 무관. T는 temperature hyperparameter. - 이렇게 얻은 guessed label data & labeled data로 MixUp을 해준다.
- MixUp을 하기 위해서 우선 labeled batch, unlabeled batch를 섞어서 \( \mathcal{W}\)를 만든다:
Labled batch: \( \hat{\mathcal{X}} = ( (\hat{x}_b,p_b); b \in (1,...,B)) \)
Unlabled batch: \( \hat{\mathcal{U}} = ( (\hat{x}_{b,k},q_b); b \in (1,...,B), k \in (1,...,K)) \) (위에서 언급했었던 집합)
\( \mathcal{W}\): \( \hat{\mathcal{X}} \)와 \( \hat{\mathcal{U}} \)를 섞은 집합 - 두 개의 데이터 쌍 \( (x_1, p_1)\), \( (x_2, p_2)\)이 있을 때 이 두 데이터를 MixUp해서 \( (x',p')\)를 얻는다. MixUp은 다음과 같이 한다. 참고로 식 (9)는 원래 MixUp에 없다; 여기서 \( x_1\)에 더 가까운 MixUp 된 데이터를 얻게 하기 위해서 도입되었다. 밑에 MixUp 사용 방식을 보면 이해가 될 것이다.
- MixUp을 하기 위해서 우선 labeled batch, unlabeled batch를 섞어서 \( \mathcal{W}\)를 만든다:
- \( \mathcal{X}' = MixUp(\hat{\mathcal{X}}_i, \mathcal{W}_i), i \in 1,...,|\hat{\mathcal{X}}|\) (\( \mathcal{W}\) 에서 \( |\hat{\mathcal{X}}|\)개의 샘플을 사용해서 \( \hat{\mathcal{X}}_i\)와 MixUp을 한다는 뜻; 식 9 때문에 labeled data \( \hat{\mathcal{X}}_i\) 에 더 가까운 MixUp이 된다)
- \( \mathcal{U}' = MixUp(\hat{\mathcal{U}}_i, \mathcal{W}_{i+|\hat{\mathcal{X}}|}), i \in |\hat{\mathcal{U}}|\) (\( \mathcal{W}\) 에서 나머지 \( |\hat{\mathcal{U}}|\)개의 샘플을 사용해서 \( \hat{\mathcal{U}}_i\)와 MixUp을 한다는 뜻; 식 9 때문에 unlabeled data \( \hat{\mathcal{U}}_i\) 에 더 가까운 MixUp이 된다)
- 원래 MixUp에서는 식 (9) 가 없지만 여기서는 x 혹은 u 에 가까운 mixed sample을 얻기 위해서 도입했다.
Loss function
- 위에서 설명한 방식으로 MixUp을 하고 다음 loss function을 사용한다
- H는 cross entropy, \( \lambda_{\mathcal{U}}\)는 hyperparameter.
Hyperparameter
- T=0.5, K=2로 두면 된다
- \( \alpha, \lambda_{\mathcal{U}}\)는 데이터셋마다 튜닝할 필요가 있지만 \( \alpha=0.75, \lambda_{\mathcal{U}} = 100\) 근처에서 출발해서 search하면 된다.
- \( \lambda_{\mathcal{U}} \)같은 경우 16,000 step을 걸쳐서 linear하게 0부터 증가시킨다.
Some more detail
- 약간 중요한데 모델을 Adam으로 학습시키고, 학습시키는 동안 모델 parameter을 exponential moving average를 취해서, 실제로 evaluation을 할 때는 이렇게 exponential moving average된 parameter를 사용한다.
Ablation
- 대부분 설명이 필요 없고, Interpolation Consistency Training은 unlabeled mixup만 사용하고 sharpening 을 사용하지 않은 경우를 말한다 (label guessing할 때는 exponential moving average를 사용)
- Mean teacher처럼 parameter에 EMA하는 것이 별로 도움이 안 되는 것을 볼 수 있다.
실험 결과
특별히 설명은 필요 없는 것 같다; 당시 나왔던 다른 방법들보다 좋다는 것이 결론. 다른 실험들도 있는데 궁금하면 논문 참고.
728x90
728x90
'논문 리뷰 > semi-supervised learning' 카테고리의 다른 글
[논문 리뷰] Unsupervised Data Augmentation for Consistency Training (0) | 2023.06.03 |
---|---|
[논문 리뷰] ReMixMatch (0) | 2023.04.06 |
[논문 리뷰] Virtual Adversarial Training (0) | 2023.04.02 |
[논문 리뷰] Mean teachers are better role models (0) | 2023.03.13 |
[논문 리뷰] Pseudo label (0) | 2023.03.12 |