[2020 NeurIPS Oral] Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning #168
Labels
Representation Learning
Self-Supervised Learning, Manifold Learning
Vision
Related with Computer Vision tasks
이때 당시의 main stream이었던 contrastive self-supervised learning은 negative pairs가 성능에 큰 영향을 미침
따라서, 기존 연구들은 large batch size, memory bank, or mining strategy 등을 사용했음
해당 논문은 negative pairs를 사용하지 않는 self-distillation 방식의 non-contrastive ssl인 Bootstrap Your Own Latent (BYOL) algorithm을 제안함
BYOL은 image augmentation, batch size에 대해 robust하며, 다양한 benchmark에서 좋은 성능을 냄
중요하다고 생각되는 부분만 간단히 요약
1. Method
BYOL Figure
BYOL Algorithm
Negative pairs를 사용하지 않고 모델을 학습하게 되면, 모든 input에 대해 1개의 vector로 embed하는 collapse 현상이 일어날 수 있음
이러한 collapse 현상을 prevent하는 straightforward solution으로, fixed randomly initialized network로 target을 생성하는 방법을 생각할 수 있음
(model이 각 image에 대한 fixed random initialized network의 output embedding을 근사하도록 학습한다는 의미)
해당 방법으로 학습한 모델의 성능은 18.8% accuracy를 달성함
절대적인 성능 수치로만 보면 좋은 representation이라고 할 수 없지만, linear probing이 1.4%, random guessing이 0.1%인 것에 비교하면 좋은 representation을 배웠다고 할 수 있음
그렇다면 subsequent online network를 new target network로 설정하여 위 process를 반복하면 representation의 quality를 높일 수 있을 것이라고 예상할 수 있으며, 이게 BYOL임
편의성을 위해 구현은 online network의 slowly moving exponential average를 target network로 사용함
즉,
target representation을 predict하도록 online network를 학습하면 representation이 potentially enhanced
한다는 experimental finding이 BYOL의 core motivation임1.1. Description of BYOL
(online network에 predictor가 없으면 mode collapse가 일어남)
2개의 network의 encoder는 같은 architecture를 사용하며, weight만 다르게 설정
Online network는 normalized prediction, normalized target projection간의 mean squared error를 이용하여 학습
(positive pairs의 online network output과 target network output간 cosine similarity를 maximize하도록 학습하는 것과 동일)
Target network는 online network의 exponential moving average (EMA)로 update
학습이 끝나면, online network의 encoder만 이용
1.2. Intuitions on BYOL's behavior
원문
BYOL은 contrastive learning에서의 negative examples와 같은 collapse를 prevent하는 explicit term이 없기에, collapsed constant representation으로 converge한다고 생각할 수 있음
하지만 target network의 update는$\nabla_{\xi} L^{\textrm{BYOL}}_{\theta, \xi}$ 의 방향이 아님
(target network는 EMA를 통해 update되는데, update의 direction이 loss를 minimize하는 direction이 아님)
따라서, BYOL이 collapsed constant representation으로 converge할 이유가 없음
물론 이러한 undesirable equilibria는 발생할 수 있으나 실험적으로 관측되지 않았으며, 저자들은 online network에서의 prediction network가 optimal하다고 가정하면 undesirable equilibria로 converge하는 것이 unstable함을 보임
자세한 과정은 원문을 참고하되, 대략적인 흐름만 요약하면 다음과 같음
Prediction network가 optimal하다는 가정 하에, expected conditional variance에 대한 gradient로 바꿀 수 있음
어떠한 random variables X, Y, X에 대해서, Var(X|Y,Z) ≤ Var(X|Y)임
(expected conditional variance inequality)
X는 target projection, Y는 online projection, Z는 training dynamics의 stochasticity로 인한 additional variability라고 하면
Var($z'_{\xi} | z_{\theta}$ ) ≤ Var($z'_{\xi} | c$ )이기에, undesirable equilibria로 converge하는 것이 더 unstable함
만약 Var($z'_{\xi} | z_{\theta}$ )를 $\theta$ 가 아닌 $\xi$ 에 대해 minimize하게 된다면, $z'_{\xi}$ 가 constant로 minimize되는 것이 variance가 더 작아지게 됨
Target network를 EMA하지 않고 online network의 hard-copy를 해도 variability가 충분한데, target network의 sudden change는 online prediction network의 optimality 가정을 break함
따라서, target network로 online network의 hard-copy를 사용하는 것은 undesirable equilibria로 converge하는 것이 unstable하다는 것을 보장할 수 없으며, near-optimality를 유지할 수 있는 EMA로 target network를 update해야함
전반적인 empirical finding에서 왜 working하는지 + 설계 이유에 대한 당위성을 이론적으로 잘 보여줬다고 생각됨
2. Experimental evaluation
BYOL로 학습한 encoder가 성능이 좋은 representation을 가졌다는 것을 다양한 실험을 통해 보임
2.1. Linear evaluation on ImageNet
2.2. Semi-supervised training on ImageNet
2.3. Transfer to other classification tasks
2.4. Transfer to other vision tasks
3. Building intuitions with ablations
3.1. Batch size & Image augmentations
BYOL은 negative examples를 사용하지 않기에, batch size, image augmentation에 robust함
Contrastive learning은 negative pairs로 인해 batch size가 커야지만 성능이 좋은데 반해, BYOL은 batch size가 크지 않더라도 좋은 성능을 냄
저자들은 BYOL이 batch size가 128 이하일 때 성능이 안좋은 이유에 대해 batch normalization layer 때문이라고 분석
(ill behaviour of batch normalization at low batch sizes)
Contrastive learning은 color histogram만 capture하는 것을 prevent하기 위해 color augmentation이 필수적인데 반해, BYOL은 color augmentation을 넣지 않더라도 좋은 성능을 냄
3.2. Bootstrapping & Ablation to contrastive methods
Target decay rate가 [0.9, 0.999]면 충분히 좋은 성능을 냄
BYOL에 temperature parameter re-tuning을 하지 않고 negative pairs를 사용하면 성능이 악화됨
즉, 굳이 negative pairs 쓰지말고 BYOL로 학습하라는 의미
The text was updated successfully, but these errors were encountered: