Self-spuervised learning ์์ ์ข์ ์ฑ๋ฅ์ ๋ด๋ Contrasive learning ์ด๋ผ๋ ๊ฐ๋ ์ ์๊ฐํ๊ธฐ ์ํด ICML2020์ ๊ฒ์ฌ๋ ๋ณธ ๋ ผ๋ฌธ์ ์ค๋ช ํฉ๋๋ค.
์๋ ์ฌ์ดํธ์ ๊ทธ๋ฆผ์ผ๋ก ์ค๋ช ์ด ์ ๋์ด ์์ด์, ์์ธํ ๋ด์ฉ์ ์ฐธ๊ณ ๋ฐ๋๋๋ค.
https://amitness.com/2020/03/illustrated-simclr/
Contrasive Learning
์ฐ์ , contrasive learning ์ 2๊ฐ์ input์ ๋คํธ์ํฌ์ ์ฃผ์ ํ์ ๋, ์ด๋ค์ด similar ํ input ์ธ์ง differentํ input ์ธ์ง๋ฅผ ๊ตฌ๋ณํด์ฃผ๊ธฐ ์ํ ํ์ต ๋ฐฉ๋ฒ์ ๋๋ค. ์๋ฅผ ๋ค์ด, ์๋ ๊ทธ๋ฆผ์์๋ Image๋ ๊ณ ์์ด์๋ similar ํ๊ณ ๊ฐ์์ง, ์ฝ๋ผ๋ฆฌ์๋ different ํฉ๋๋ค.
๊ทธ๋ฐ๋ฐ, ์ด๋ฌํ ๊ณผ์ ์ ์ ์ผ์ชฝ ๊ทธ๋ฆผ์ฒ๋ผ supervised ๋ฐฉ์์ผ๋ก ํ์ต์ํค๋ฉด, ์ง์ labeling์ ํด์ผํ๋ ๋ฒ๊ฑฐ๋ฌ์์ด ์๊ธฐ ๋๋ฌธ์(์ฌ์ค ์ด๋ด๊ฑฐ๋ฉด ๊ทธ๋ฅ supervised๋ก ํ์ต์ํค๊ณ class๋ฅผ ์ง์ ๋ถ๋ฅํ๋๊ฒ ๋ ์ข๊ฒ ์ฃ ), ์ด๋ฅผ ์๋ํํ๊ธฐ ์ํด self-supervised learning ๋ฐฉ์์ ์ด์ฉํฉ๋๋ค.
์ ์ ์กฐ๊ฑด์ ์ด๋ฏธ์ง๊ฐ feature๋ก ์ represent ๋์ด์ผ ํ๊ณ , ๋ ์ด๋ฏธ์ง๊ฐ embedding๋ ๋ feature ๊ฐ์ similarity๋ฅผ ๊ณ์ฐํ๋ ๋ฉ์ปค๋์ฆ์ด ํ์ํฉ๋๋ค.
์ฌ๊ธฐ์ feature๊ฐ ์ represent ๋์ด์ผ ํ๋ค๋ ๋ป์, classification ํ ๋๋ ๋์ผํ class์ ์ด๋ฏธ์ง๋ค์ด feature map์์ ์ ๋ชจ์ฌ์๋ ๊ฒ๋ ์ค์ํ์ง๋ง, ๋ค๋ฅธ class์ feature space์์ ๊ตฌ๋ณ๋ง ์ ๋๋ฉด(distance ๊ฐ ๋ฉ๋ฉด) ๋ถ๋ฅ ์ฑ๋ฅ๋ ๋์์ง๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ฐ๋ฐ contrasive learning ์์๋ ๋์ผํ class์ ์ด๋ฏธ์ง๋ค ๋ผ๋ฆฌ๋ deformation์ด ์๋๋ผ๋ feature space ์์์ ๋งค์ฐ ๊ฐ๊น๊ฒ ์์นํ๊ธฐ๋ฅผ ๋ฐ๋๋ค๋ ๋ป์ ๋๋ค. (์ฆ, ์์ ์๋ ๊ณ ์์ด๋ ๋์์๋ ๊ณ ์์ด๋ ๋์ผํ ๊ณ ์์ด ์ด๋ฏ๋ก feature space ์์ distance๊ฐ ๊ฐ๊น์์ผ similar ํ๋ค๊ณ ๋งํ ์ ์๊ฒ ๋ฉ๋๋ค.)
SimCLR Framework
๋ณธ ๋ ผ๋ฌธ์์๋, ์์ ๋ฌธ์ ๋ฅผ self-supervised ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ๋งํ๊ธฐ ์ํด 'SimCLR' ์ด๋ผ๋ ํ๋ ์์ํฌ๋ฅผ ์ ์ํฉ๋๋ค.
์ด ๋ฐฉ์์ ๋งค์ฐ ๊ฐ๋จํ๋ฐ, original ์ด๋ฏธ์ง๋ฅผ ์๋ก ๋ค๋ฅธ data augmentation(in image space)๋ฅผ ํตํด xi, xj๋ฅผ ์์ฑํ๊ณ ์ด๋ฅผ ๋์ผํ Encoder(CNN)์ ๋ฃ์ด์ feature representation hi, hj ๋ฅผ ์ป๊ณ , fc layer๋ฅผ ํต๊ณผ์์ผ zi, zj๋ฅผ ์ป์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ xi์ xj์ ์ต์ข output ์ธ zi์ zj์ similarity๋ฅผ ์ต๋ํ์ํค๋ ๋ฐฉํฅ์ผ๋ก ํ์ต์ ์งํํฉ๋๋ค.
SimCLR training ๊ณผ์
์๋์์ ๊ตฌ์ฒด์ ์ธ ๊ณผ์ ์ ์ค๋ช ํ๋ คํฉ๋๋ค. ์ฐ์ , label์ด ์๋ ์๋ฐฑ๋ง๊ฐ์ ์ด๋ฏธ์ง๋ก training corpus๊ฐ ์๋ค๊ณ ๊ฐ์ ํฉ๋๋ค.
๋จผ์ raw ์ด๋ฏธ์ง์์ ํฌ๊ธฐ๊ฐ N์ธ batch๋ฅผ ์์ฑํฉ๋๋ค. (๊ทธ๋ฆผ ์์์์๋ N=2, ๋ ผ๋ฌธ์์๋ N=8192)
Data augmentation์ (crop + flip + color jitter + grayscale)์ ์กฐํฉ์ ์ ์ฉํ๋ random transformation function T๋ฅผ ์ ์ํ์ฌ ์ฌ์ฉํฉ๋๋ค.
batch์ ๊ฐ ์ด๋ฏธ์ง์ ๋ํด 2๊ฐ์ ์ด๋ฏธ์ง ์์ ์ป๊ธฐ ์ํด data augmentation์ ์ ์ฉํ์ฌ ์ด 4๊ฐ(2*N =2*2 = 4)์ ์ด๋ฏธ์ง๋ฅผ ์ป์ต๋๋ค.
pair์ ์ด๋ฏธ์ง๋ค์ encoder(๋ ผ๋ฌธ์์๋ resnet50)๋ฅผ ํตํด ๊ฐ๊ฐ ์ด๋ฏธ์ง representation(feature hi, hj) ์ป๊ณ , hi, hj๋ฅผ Project Head(fc layer)์ ํต๊ณผ์์ผ์ non-linear transformation์ ํด์ zi, zj์ projectionํฉ๋๋ค.
๋ ผ๋ฌธ์์ ๋ง์ ๊ทธ๋ด์ธํ๊ฒ ํ๋๋ฐ, CNN์ผ๋ก embeddingํ feature๋ฅผ fc layer์ ํต๊ณผ์์ผ์ ๋น์ ํ์ฑ์ ์ถ๊ฐํด์ค์ ๋ ๋ณต์กํ representation๋ ์ ํ๋๋ก ํด์ฃผ๋ ์ญํ ์ธ ๊ฒ ๊ฐ์ต๋๋ค. CNN๋ง ํต๊ณผํ feature๋ก similarity๋ฅผ ๊ณ์ฐํ๋ฉด ๋์ผํ class ์์๋ deformation์ด ํฌ๋ฉด(๊ณ ์์ด์ ๊ฒฝ์ฐ ์์ธ๋ ์นด๋ฉ๋ผ ๊ฐ๋๊ฐ ๋ค๋ฅด๋ฉด) similarity๊ฐ ๋ฎ์์ง ์๋ ์์ผ๋๊น์.
์ต์ข output์ ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ๊ฐ augmented ์ด๋ฏธ์ง๋ค์ ๋ํ feature embedding z์ ๋๋ค. ์ด์ ์ด z๋ฅผ ์ด์ฉํด์ loss๋ฅผ ๊ณ์ฐํ์ฌ ๋ชจ๋ธ์ ํ์ต์ํต๋๋ค.
์ด์ feature embedding z์ loss function์ ์ด์ฉํ์ฌ ๋คํธ์ํฌ๊ฐ ์ด๋ฏธ์ง ์์ similar ์ different ๋ฅผ ํ์ตํ๋ฉด ๋ฉ๋๋ค.
Contrasive Loss ๊ณ์ฐ
feature ๊ฐ์ similarity๋ cosine similarity๋ฅผ ์ฌ์ฉํ์ฌ ๊ณ์ฐ๋ฉ๋๋ค. ์ฆ, ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ๋ ์ด๋ฏธ์ง xi, xj ์ ๋ํ feature embedding์ธ zi์ zj์ consine similarity๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ด๋ ๊ฒ cosine similarity๋ฅผ ์ฌ์ฉํ๋ฉด ์ ๊ทธ๋ฆผ์ฒ๋ผ ๊ฐ์ ์ด๋ฏธ์ง์์ augmentation ๋ ์ด๋ฏธ์ง๋ค ๋ผ๋ฆฌ๋ similarity๊ฐ ๋๊ณ (๊ณ ์์ด-๊ณ ์์ด, ์ฝ๋ผ๋ฆฌ-์ฝ๋ผ๋ฆฌ) ์๋ก ๋ค๋ฅธ ์ด๋ฏธ์ง์์ augmentation ๋ ์ด๋ฏธ์ง๋ค ๋ผ๋ฆฌ๋ similarity๊ฐ ๋ฎ๊ฒ(๊ณ ์์ด-์ฝ๋ผ๋ฆฌ) ๋ฉ๋๋ค.
SimCLR์ NT-Xent loss (Normalized Temperature-Scaled Cross-Entropy Loss) ๋ผ๋ contrasive loss๋ฅผ ์ฌ์ฉํฉ๋๋ค.
๋จผ์ batch์์ augmentation๋ ์ด๋ฏธ์ง ์์ ํ๋์ฉ ๊ฐ์ ธ์์ ๋ ์ด๋ฏธ์ง๊ฐ ์ ์ฌํ ํ๋ฅ ์ ์ป๊ธฐ์ํด softmax ํจ์๋ฅผ ์ ์ฉํฉ๋๋ค. ์ด softmax ๊ณ์ฐ์ ๋๋ฒ ์งธ augmentation ๋ ๊ณ ์์ด ์ด๋ฏธ์ง๊ฐ pair์ ์ฒซ๋ฒ ์งธ ๊ณ ์์ด ์ด๋ฏธ์ง์ ๊ฐ์ฅ ์ ์ฌํ ํ๋ฅ ์ ๊ตฌํ๋ ๊ฒ์ ๋๋ค. ์ฌ๊ธฐ์ batch์ ๋๋จธ์ง ๋ชจ๋ ์ด๋ฏธ์ง๋ค์ dissimilar image(negative pair)๋ก ์ํ๋ง๋ฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ ์ ๊ณ์ฐ์ log๋ฅผ ์์๋ก ์ทจํด์ pair์ ๋ํ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค. ๋ค๋ฅธ ์ด๋ฏธ์ง pair๋ loss๋ฅผ ๊ณ์ฐํ๊ณ ๋์ผํ pair์ ๋ํด ์ด๋ฏธ์ง์ ์์น๊ฐ ๋ฐ๋๋ ๊ฒฝ์ฐ์๋ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
๋ง์ง๋ง์ผ๋ก, ํฌ๊ธฐ๊ฐ N=2์ธ batch์ ๋ชจ๋ pair์ ๋ํ loss๋ฅผ ๊ณ์ฐํ๊ณ ํ๊ท ์ ์ทจํฉ๋๋ค. Loss๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ์ตํ๋ฉด์ encoder, projection header representation ์ด ํฅ์๋์ด feature space์์ ์ ์ฌํ ์ด๋ฏธ์ง๋ฅผ ๋ ๊ฐ๊น๊ฒ ์์น์ํต๋๋ค.
Downstream Tasks (Transfer learning)
SimCLR ๋ชจ๋ธ์ด contrasive learning ์ผ๋ก ํ์ต์ด ์๋ฃ๋๋ฉด transfer learning์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด base encoder์์ ์ป์ feature representation์ ์ฌ์ฉํ์ฌ classification, detection๋ฑ์ down stream task์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์๋ figure๋ค์ ๋ ผ๋ฌธ์ ์ค๋ช ๋ SimCLR algorithm์ ๋ํ ๋ด์ฉ์ ๋๋ค.
์๋๋ data augmentation์ ์์์ evaluation ๊ฒฐ๊ณผ ํ์ ๋๋ค. augmentation์ 2๊ฐ๋ฅผ ์กฐํฉํ ๊ฒฐ๊ณผ๊ฐ 1๊ฐ๋ฅผ ์ฐ๋ ๊ฒฝ์ฐ๋ณด๋ค ์ข์๊ณ crop + color distortion์ด ์ ์ผ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค.