Barlow Twins: Self-Supervised Learning via Redundancy Reduction
Barlow Twins: Self-Supervised Learning via Redundancy Reduction
연관 포스트:
Abstract
- 같은 데이터에서 왜곡된(distorted/augmented) 두 개의 sample을 똑같은 network를 통과한 output의 cross-correlation matrix을 계산하여
자연스럽게 collapse를 피할 수 있는 objective function 제안- 다른 SSL과 다르게 large batch size가 필수가 아니고, network twin도 비대칭 아님
- very high-dimensional output vector에 이점 있음
1. Introduction
- SSL 방식들은 공통적으로 다양한 distortion/data augmentation에 invariant한 representation 학습을 목표로 함
- 여러가지 Siamese Network에 sample의 distorted version을 넣고 representation의 유사도를 최대화하는 방법 사용
- (1) Contrastive Methods
- ex) SimCLR 설명 포스트
- ‘positive’와 ‘negative’ sample pair를 구성
- loss function에서 각각 다르게 취급됨
- main network과 momentum encoder를 비대칭적으로 update하는 방식 사용 가능
- (2) Clustering Methods
- 2개의 distorted sample 중,
1개는 loss target
다른 1개는 위의 target을 predict - optimization scheme: k-means(DeepCluster), non-differentiable operators(SwAV, SeLa)
- 2개의 distorted sample 중,
- (3) Similar Line of Works
- ex) BYOL, SimSiam
- 비대칭
- 구조: predictor network
- parameter update: 한 개의 distorted sample로 network update
‘stop-gradient’가 문제 해결에 필수
- Barlow Twins with Redundanct-Reduction
- Barlow의 Possible Principles Underlying the Transportation of Sensory Messages 에서 나온 개념
- sensory processing은 high-redundant sensory input을 fatorial code로 recoding하는 것
- factorial code: 통계적으로 독립된 구성 요소로 이루어진 code
- twin embeddings으로 만든 cross-correlation matrix를 identity matrix과 가깝도록 object function 구성
- 큰 batch, 비대칭적인 구조(prediction network, momentum encoder, non-differentiable operators, stop-gradient) 필요 없음
- high-dimensional embedding에 유리
2. Methods
1) Description of Barlow Twins
(1) Network 구조
- distorted images에 대한 joint embedding 진행
- dataset에서 sample된 batch $X$ 에 data augmentations $\mathcal{T}$ 하여 distorted view로 구성된 batch $Y^A,\ Y^B$ 생성
- function $f_\theta$ 에 넣어서 embeddings $Z^A, \ Z^B$ 만듦 (trainable parameters $\theta$)
- $Z^A, \ Z^B$ 는 batch dimension에서 mean-centered(zero mean)로 가정
(2) Loss Function $\mathcal{L}_{\mathcal{BT}}$
\[\mathcal{L}_{\mathcal{BT}} \triangleq \underbrace{\sum_i(1-C_{ii})}_{\text{invariance term}} + \lambda \underbrace{\sum_i\sum_{i\neq j}C_{ij}^2}_{\text{redundancy reduction term}}\]- $\lambda$: 첫번째와 두번째 loss term의 균형(trading off the importance)을 위한 positive constant
- $C$: 똑같은 network의 batch dimension을 따라 나온 outputs의 cross-correlation matrix
\(C_{ij} \triangleq\frac{\sum_b\ z_{b,i}^A\ z_{b,j}^B}{\sqrt{\sum_b\ (z_{b,i}^A)^2}\ \sqrt{\sum_b\ (z_{b,j}^B)^2}}\)
- $b$: batch sample index
- $i,\ j$: networks’ outputs vector dimension index
- $C$: network output 크기의 square matrix,
-1(perfect anti-correlation) ~ 1(perfect correlation)
- invariance term
- cross-correlation matrix의 주대각선 성분들을 1에 가깝도록
- embedding이 distortion에 invariant 하도록
- redundancy reduction term
- 주대각선 성분이 아닌 element가 0이 되도록
- embedding의 vector components decorrelate
- output units 간의 redundancy 줄여줌
- embedding vector의 soft-whitening constant로 볼 수 있음 (5. Discussion에서 다룸)
- 정보 이론의 Information Bottleneck(IB) Objective 으로도 해석 가능
- Appendix A에 정리됨
- 장점
- (1) large number of negative sample이 필요 없어서 작은 batch에서도 잘 작동됨
- (2) benefits from very high-dimensional embeddings
2) Implementation Details
(1) Image Augmentation
- 각 input image는 2번 transformation 수행
- 항상 : random cropping, 224 $\times$ 224 resizing
- 랜덤 : horizontal flipping, color jittering, convert to grayscale, Gaussian blurring, solarization(과노출로 반전)
- 뒤 2개의 확률은 다름
- BYOL 논문과 똑같은 augmentation parameter 사용
(2) Architecture
- encoder: ResNet-50 (final classification layer 없이 2048 output)
- projector network
- 3 linear layer
- 모두 8192 output units
- layer 사이에 Batch Normalization layer와 ReLU 사용
- input $\quad=$ encoder $\Rightarrow\quad$ representation $\quad=$ projector $\Rightarrow\quad$ embedding
- embedding은 downstream task에 사용하고, embedding은 loss 계산에 사용
(3) Optimization
방법 | 내용 |
---|---|
epoch | 1000 |
batch size | 2048 (256~) |
optimizer | LARS |
learning rate | weight: $0.2 \times (\text{batch size}/256)$ bias & batch norm: $0.0048 \times (\text{batch size}/256)$ |
weight decay | $1.5\times 10^{-6}$ |
linear warmup | for 10 epochs |
scheduler | cosine decay, factor 1000 |
trade off parameter | $\lambda=5\times 10^{-3}$ |
기타 | bias와 batch norm parameters은 LARS adaptation과 weight decay에서 제외됨 |
3. Results
1) Linear and Semi-Supervised Evaluation on ImageNet
<Linear Evaluation on ImageNet>
<Semi-Supervised Evaluation on ImageNet>
2) Transfer to Other Datasets and tasks
<Image Classification with Fixed Features>
ResNet-50 Freeze
- Places-205: scene classification
- VOC07: multi-label image classification
- iNaturalist2018: fine-grained image classification
<Object Detection and Instance Segmentation>
Fine-tune ResNet-50
4. Ablations
- 1000 epochs $\Rightarrow$ 300 epochs
- accuracties는 ImageNet training set에서의 2048 dimension으로 학습된 결과값들
1) Loss Function Ablations
(a)Baseline ~ (g)Cross-entropy with temp.
- (b): removing invariance term(on-diagonal) $\Rightarrow$ worse
- (c): removing redundancy reduction term(off-diagonal) $\Rightarrow$ collapsed
- (d): normalize along feature dimension(unit sphere) $\Rightarrow$ slightly reduced
- embedding을 batch dimension으로 normalize(with mean subtraction)
$\Rightarrow$ embedding을 feature dimension으로 normalize(without mean subtraction)
$\Rightarrow {\text{normalized cross-correlation 아닌}} \rightarrow \text{unnormalized covariance matrix}$
- embedding을 batch dimension으로 normalize(with mean subtraction)
- (e): projector network에서 batch normalization 제거 $\Rightarrow$ barely affected
- (f): (e) + loss의 cross-correlated matrix를 cross-covariance matrix로 바꿈(batch 축으로 noramlize 안 함) $\Rightarrow$ substainally reduced
- (g): cross-entropy with temperature $\tau$ $\Rightarrow$ reduced
\(\mathcal{L}=-log\sum_i exp(C_{ii}/\tau) + \lambda log\sum_i\sum_{i\neq j}exp(max(C_{ij},0)/\tau)\)
2) Robustness to Batch Size
- performed grid search of LARS learning rate for each batch size
3) Effect of Removing Augmentations
- not robust to augmentations like SimCLR $\rightarrow$ BYOL is robust
-다르게 보면, 특정한 distortion 사용애 대해 control이 더 좋음
4) Projector Network Depth & Width
- BYOL과 SimCLR은 projector network에서 ResNet output을 엄청 줄이고 일정 수준 이상 output이 커져도 변함 없음(saturate)
- Barlow Twins는 projector network의 output이 크면 클수록 좋음
- ResNet-50의 output은 2048로 고정되었음에도 이런 결과 나옴
- 다른 방법들과 유사하게 projector network layer가 많로질수록 좋았고, 3 layer에서 saturate
5) Breaking Symmetry
- asymmetries가 성능을 더 해침
6) BYOL with a Larger Projector/Predictor/Embedding
- 성능이 좋아지진 않음
7) Sensitivity to $\lambda$
- $\lambda$에 sensitive 하지 않음
5. Discussion
1) Comparison with Prior Art
(1) InfoNCE
- Contrastive SSL 에서 자주 사용하는 loss function
- InfoNCE
- $b$: sample index
- $i$: output의 vector component index
- $z^A,\ z^B$: twin network outputs
- $\tau$: temperature in analogy to statistical physics
- Barlow Twins loss re-write
- 공통점
- distorted된 data를 twin network에 넣어도 embedding이 invariant하게, 학습된 embedding에 대해서는 variability가 maximized하는 것이 목표
- 위의 variability에 대한 측정이 batch statistics에 의존
- 차이점
InfoNCE | Barlow Twins |
---|---|
sample들의 모든 pair에 대한 pairwise distance를 최대화하므로써 embedding variability 최대화 | embedding vector decorrelation을 통해 embedding variability 최대화 |
non-parametric estimation of the entropy of the distribution of embeddings - prone to the curse of the dimensionality - require a large number of samples |
proxy entropy estimator of the distribution of embeddings under a Gaussian parameterization(Appendix A) - simplification $\rightarrow$ fewer samples, very large dimensional embeddings |
normalized along the feature dimension(cosine similarity) | normalized along the batch dimension |
trade off parameter 없음 | trade off parameter $\lambda$ 있음(Appendix A) |
hyperparameter $\tau$ 있음 - non-parametric kernel density estimation의 kernel width로 해석 가능 - batch 안의 hardest negative sample에 대한 상대적인 중요도의 weight 값 |
- |
- Diff with MoCo
- MoCo
- large batch에 대한 의존도 낮춤
- negative samples로 이루어진 dynamic dictionary ($\gt 60,000 \text{ sample embeddings}$)와 moving-averaged encoder
- Barlow Twins
- 큰 dictionary 필요없고, 작은 batch에서도 잘 작동함
- 큰 dictionary 필요없고, 작은 batch에서도 잘 작동함
- MoCo
(2) Asymmetric Twins
- BYOL과 SimSiam은 simple cosine similarity를 사용하여 contrastive term 없이도 문제를 성공적으로 해결
- BYOL
- predictor network가 대칭적 구조를 깸
- exponential moving average가 target network weight 학습을 늦춤
- 다른 연구에서 collapse를 방지하기 위해, moving average는 필수적인 요소는 아니지만,
하나의 branch에 stop-gradient를 적용하는 것과 predictor network는 필수적임을 밝힘 - batch normalization 혹은 (alternatively) group normalization 또한 collapse를 피할 수 있는 요소라고 함
- Barlow Twins처럼 large batch나 batch 안의 다른 sample과의 관계를 objective function에서 고려할 필요 없음
- 하지만, 비대칭적인 방법은 전반적인 학습 목표에 대한 최적화로 설명될 수 없음
- implementation choices나 non-trivial learning dynamics의 결과를 통해 피함
- Barlow Twins는 construction으로 문제 해결(until their principle is discovered)
(3) Whitening
- W-MSE: twin networks의 whitened embeddings 사이의 simple cosine similarity 계산하기 전에
각 batch embedding의 미분가능한 whitening operation - Barlow Twins의 redundancy reduction term이 batch embeddings의 whitening encourages
(4) Clustering
- contrastive-like 비교를 하지만, 모든 pairwise 거리를 계산하지 않음
- collapse에 약함(k-means의 empty cluster, careful implementation 필요)
- batch size에 상관없이 학습은 가능하나, cluster 개수가 batch 크기보다 크면 feature들을 보관해야 할 필요가 있음
(5) Noise as Targets
- sample들을 fixed random targets on the unit sphere
(whitening으 한 형태로 해석 가능) - single network 사용, distortion 안 사용함
- 학습될 representation의 flexibility 제한할 수 있음
(6) IMAX
- SSL 초기의 loss function
- $\vert\quad \vert$: matrix의 determinant
- $\mathcal{C}$: convariance
- Barlow Twins과 유사한 방법이나,
IMAX는 바로 information quantity로 계산하고, extra trade-off parameter $\lambda$ 없음
2) Feature Directions
- embedding dimension이 클 수록 성능이 좋아졌으나, memory 사용량이 커지기에 잉 대한 새로운 방법 추가 연구 필요
- Information Bottleneck principle의 하나의 implementation이기에 더 개발할 여지가 많음
Appendix A
- Mutual Information(MI)
- 설명된 블로그
- joint distribution $p(X,\ Y)$가 $p(X)p(Y)$와 얼마나 비슷한지 측정 \(\mathbb{I}(X;\ Y) \triangleq \mathbb{KL}(p(x,y)\Vert p(x)p(y)) = \sum_y\sum_x p(x,y)log\frac{p(x,y)}{p(x)p(y)}\)
-
X, Y가 independent하면,
$p(x,y)=p(x)p(y)\quad \Rightarrow\quad log\frac{p(x,y)}{p(x)p(y)}=1 \quad\Rightarrow\quad MI = 0$
- 이를 conditional entropy로 바꾸면 \(\mathbb{I}(X;\ Y) = \mathbb{H}(X) - \mathbb{H}(X \vert Y) = \mathbb{H}(Y) - \mathbb{H}(Y \vert X)\)
- $\mathbb{H}$: entropy
- $Y$에 대한 정보를 앎으로써 $X$에 대한 불확실성이 얼마나 감소했는지?
$\Rightarrow$ $X$가 $Y$에 얼마나 dependent한 지?
- sample에 적용된 특정한 distortions 정보는 최소화하면서 sample에 있는 정보를 최대한 보존
- $I(\cdot\ ,\ \cdot )$: Mutual Information
- $\beta$: positive scalar, information 보존하면서 distortion에 invariant 사이의 trade off
- $H(\cdot)$: entropy $\qquad H(\cdot\vert\cdot)$: conditional entropy
-
$H(Z_\theta \vert Y)$는 zero entropy를 가지므로 0으로 수렴해서 사라짐
$\Rightarrow$ 이를 정리하면 마지막 식이 나오게 됨
-
high dimensional signal의 entropy를 계산하면 single batch보다 크기가 커짐 $\Rightarrow$ Gaussian distribution 가정
- covariance function의 determinant의 log로 계산이 단순해짐 \(\mathcal{IB}_\theta = \mathbb{E}_X\ log\vert C_{Z_\theta\vert X}\vert + \frac{1-\beta}{\beta}lof\vert C_{Z_\theta}\vert\)
- simplification과 approximations
- $\frac{1-\beta}{\beta}$를 양수 $\lambda$로 대체
- covariance matrices의 determinant를 사용해서 바로 optimize하면 SoTA가 안 나옴
- 두 번째 term을 cross-correlation matrix의 Frobenius norm의 최소화로 바꿈
- loss 계산 전에 batch dimension으로 representaion이 1로 rescale 됐다면(cross-correlation은 rescaling에 invariant),
mimimization은 off-diagonal terms에만 영향을 미치고 0으로 수렴할 수 있도록 도와줌
- 두 번째 term은 원래 twin networks에서 나온 값 중 하나로 auto-correlation해서 계산해야하나,
auto-correlation이나 cross-correlation의 차이가 별로 없었음