728x90
728x90
개요
논문 링크: https://arxiv.org/abs/1703.01780
Semi-supervised learning 기법 중 하나. Labeled data와 unlabeled data를 동시에 학습할 때 unlabeled data에 대한 classification target (label)을 teacher network라는 별도의 network를 사용해서 얻는다. 이때 teacher network는 학습하는 network의 parameter들을 exponential moving average를 취해서 얻는 것이 핵심 아이디어다.
Figure 1을 보면서 더 자세히 소개를 하겠다.
- data가 부족하면 overfitting될 수 있다 (a).
- 이런 경우 data의 의미를 바꾸지 않는 noise를 넣어주면 data 주변에도 모델이 더 robust 해진다. (b)
- label이 없는 데이터를 활용하기 위해서 \( \Gamma\) model이라는 기법에서는 unlabeled data에 대해서 noise가 있을 때 결과와 noise가 없을 때 결과가 같게 나오도록 consistency loss를 줘서 학습한다. 이 경우 같은 모델이 teacher이 되기도 하면서 (noisy unlabeled data로 target을 만드는 역할) student가 되기도 한다 (labeled data와 teacher이 만드는 target을 학습). 이렇게 얻은 label에 대한 loss를 너무 강조하면 실제 label에 대한 정보를 학습하는데 방해가 되고 confirmation bias가 생겨서 (c)처럼 학습될 수 있다.
- 이것을 해결하기 위해서 noise를 잘 선택하는 방법이 있다 (Virtual Adversarial Training). 또 다른 방법은 teacher을 잘 선택하는 방법이 있다. 이 논문에서는 후자에 대한 분석을 한다. 한 가지 방법은 teacher model에 noise를 넣어서 ensemble 한 target을 만드는 것이다. (d)에서 속이 빈 작은 동그라미들이 noise를 넣은 target들이고 속이 빈 큰 동그라미가 이런 target들을 averaging 해서 만든 target이다. 이것을 \( \Pi\) model이라고 부른다.
- Temporal Ensembling 을 사용해서 \( \Pi\) model을 더 좋게 만들 수 있다. 여기서 Temporal Ensembling이란 prediction에 대해서 exponential moving average (EMA)를 취해서 unlabeled data에 대한 target (i.e. label)을 얻는 것을 말한다. 이것은 과거의 모델들을 ensembling 하는 효과가 있다 (e). 하지만 이렇게 만든 target은 업데이트가 느리다 (epoch가 지나야 업데이트가 된다).
방법
아이디어
- Mean Teacher의 핵심 아이디어는 model parameter에 대해서 EMA를 취해서 teacher model을 만드는 것이다. 이 teacher model이 unlabeled data에 target을 만들게 된다.
- Temporal Ensembling은 unlabeled data의 target 값의 업데이트가 느리지만, Mean Teacher은 gradient descent를 할 때마다 업데이트가 된다.
- unlabeled data에 대한 과거 label들을 저장할 필요가 없어서 큰 데이터셋에 대해서 적용하는데 무리가 없고 on-line learning이 가능해진다
Loss
- Labeled data에 대해서는 평소대로 classification loss 사용. 추가로 모든 데이터에 대해서 consistency loss를 준다:
\[ J(\theta) = E_{x,\eta,\eta'} [|| f(x,\theta',\eta') - f(x,\theta,\eta)||^2]\]- 여기서 f는 neural network이고, \( \theta\)가 모델 parameter이다. Teacher network같은 경우 EMA로 구한다: \( \theta'_t = \alpha \theta'_{t-1} + (1-\alpha) \theta_t\)
- \( \eta, \eta'\)는 noise
- Figure 참고
결과
- SVHN dataset, CIFAR-10 dataset에 대해 실험 결과는 Table 1, 2 참고. 데이터가 많이 없는 경우 확실히 label이 있는 데이터만 사용했을 때에 비해서 성능이 좋아진다.
- Virtual Adversarial Training이 조금 더 좋은 성능을 주는 것을 볼 수 있다.
728x90
728x90
'논문 리뷰 > semi-supervised learning' 카테고리의 다른 글
[논문 리뷰] Unsupervised Data Augmentation for Consistency Training (0) | 2023.06.03 |
---|---|
[논문 리뷰] ReMixMatch (0) | 2023.04.06 |
[논문 리뷰] Virtual Adversarial Training (0) | 2023.04.02 |
[논문 리뷰] MixMatch (0) | 2023.03.22 |
[논문 리뷰] Pseudo label (0) | 2023.03.12 |