[Paper Review] A Simple Framework for Contrastive Learning of Visual Representations / SimCLR / Self-supervised

by 뭅즤 2022. 1. 27.
To introduce contrastive learning, an approach that performs well in self-supervised learning, this post walks through this paper, which was published at ICML 2020.

 

The site below explains the paper well with illustrations, so please refer to it for the details.

https://amitness.com/2020/03/illustrated-simclr/ 

 

Contrastive Learning

First, contrastive learning is a training method for telling whether two inputs fed into a network are similar or different. For example, in the figure below, the image is similar to the cat and different from the dog and the elephant.

 

 

However, training this in a supervised way, as in the left figure above, means labeling everything by hand (and honestly, if you are going to do that, you might as well train a supervised model and classify the classes directly). Self-supervised learning is used to automate this.

There are two prerequisites: images must be represented well as features, and there must be a mechanism for computing the similarity between the features that two images are embedded into.

What does it mean for features to be represented well? For classification, it helps when images of the same class cluster together in feature space, but as long as they are well separated from other classes (the distance is large), classification accuracy will be high. Contrastive learning asks for more: images of the same class should sit very close together in feature space even when they are deformed. (That is, a sitting cat and a lying cat are the same cat, so their distance in feature space has to be small before we can call them similar.)

 

SimCLR Framework

To model the problem above in a self-supervised way, the paper proposes a framework called 'SimCLR'.

The approach is very simple: two different data augmentations (in image space) are applied to the original image to produce xi and xj, both are fed into the same encoder (a CNN) to get the feature representations hi and hj, and these are passed through an fc layer to get zi and zj.

Training then proceeds in the direction that maximizes the similarity between zi and zj, the final outputs for xi and xj.

 

SimCLR Training Process

 

The concrete steps are described below. To start, assume we have a training corpus of millions of unlabeled images.

 

 

First, a batch of size N is drawn from the raw images. (N=2 in the figure example; the paper uses batch sizes up to N=8192.)

For data augmentation, a random transformation function T is defined that applies a combination of (crop + flip + color jitter + grayscale).

 

To get a pair of images for each image in the batch, data augmentation is applied twice per image, giving 4 images in total (2*N = 2*2 = 4).
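As a rough illustration of this step, here is a minimal pure-Python sketch. Everything in it is a toy assumption for illustration: the "images" are 4x4 grids of numbers, the function names (`random_crop`, `random_flip`, `make_views`) are made up, and `T` only does crop + flip (color jitter and grayscale are left out), so it is not the paper's augmentation pipeline.

```python
import random

def random_crop(img, size, rng):
    """Cut a random size x size window out of a 2D-list image."""
    top = rng.randrange(len(img) - size + 1)
    left = rng.randrange(len(img[0]) - size + 1)
    return [row[left:left + size] for row in img[top:top + size]]

def random_flip(img, rng):
    """Mirror the image horizontally with probability 0.5."""
    return [row[::-1] for row in img] if rng.random() < 0.5 else img

def T(img, rng, crop_size=3):
    """Toy stand-in for the random transformation T (crop + flip only)."""
    return random_flip(random_crop(img, crop_size, rng), rng)

def make_views(batch, rng):
    """Return 2N views; views 2k and 2k+1 come from the same source image."""
    views = []
    for img in batch:
        views.append(T(img, rng))  # first random view
        views.append(T(img, rng))  # second, independently sampled view
    return views

rng = random.Random(0)
# N = 2 toy "images": 4x4 grids of pixel intensities
batch = [[[r * 4 + c for c in range(4)] for r in range(4)],
         [[100 + r * 4 + c for c in range(4)] for r in range(4)]]
views = make_views(batch, rng)
print(len(views))  # 2 * N = 4
```

The only part that matters for the rest of the post is the ordering: each source image contributes two neighboring views, which later become the positive pair.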

 

Each image in a pair is passed through the encoder (ResNet-50 in the paper) to get its image representation (features hi, hj), and hi, hj are then passed through the projection head (fc layers), which applies a non-linear transformation and projects them to zi, zj.
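To make the projection head concrete, here is a small sketch, assuming the form the paper describes (an MLP with one hidden layer and a ReLU, i.e. z = W2 · ReLU(W1 · h)). The dimensions, the random weights, and the random vector standing in for the encoder output are all hypothetical; a real setup would use a deep learning framework and trained weights.

```python
import random

def linear(x, W):
    """Multiply vector x by weight matrix W (given as a list of rows)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(x):
    """Element-wise max(0, v)."""
    return [max(0.0, v) for v in x]

def projection_head(h, W1, W2):
    """g(h) = W2 * ReLU(W1 * h): one hidden layer followed by a linear layer."""
    return linear(relu(linear(h, W1)), W2)

rng = random.Random(0)
d_h, d_hidden, d_z = 8, 8, 4  # toy sizes, far smaller than real ResNet-50 features
W1 = [[rng.gauss(0, 1) for _ in range(d_h)] for _ in range(d_hidden)]
W2 = [[rng.gauss(0, 1) for _ in range(d_hidden)] for _ in range(d_z)]

h_i = [rng.gauss(0, 1) for _ in range(d_h)]  # stand-in for the encoder output h_i
z_i = projection_head(h_i, W1, W2)
print(len(z_i))  # 4
```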

 

The paper makes it sound sophisticated, but my reading is that passing the CNN-embedded features through fc layers adds non-linearity, which helps handle more complex representations. If similarity were computed on features that had only gone through the CNN, it could come out low even for the same class when the deformation is large (for a cat, a different pose or camera angle).

 

The final output, as in the figure below, is a feature embedding z for each augmented image. This z is then used to compute the loss and train the model.

All that remains is for the network to learn, from the feature embeddings z and the loss function, which image pairs are similar and which are different.

 

 

Computing the Contrastive Loss

 

The similarity between features is computed with cosine similarity. That is, as in the figure below, we compute the cosine similarity between zi and zj, the feature embeddings of the two images xi and xj.

With cosine similarity, as in the figure above, images augmented from the same image should ideally have high similarity (cat-cat, elephant-elephant), and images augmented from different images should have low similarity (cat-elephant).

SimCLR uses a contrastive loss called NT-Xent loss (Normalized Temperature-Scaled Cross-Entropy Loss).
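Before getting to the loss itself, the pairwise cosine similarity described above can be sketched in a few lines. The 2-D embeddings here are made-up numbers chosen so that the two "cat" views point in roughly the same direction; real z vectors come out of the projection head and have many more dimensions.

```python
import math

def cosine_similarity(u, v):
    """sim(u, v) = (u . v) / (|u| * |v|), a value in [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def similarity_matrix(z):
    """Pairwise cosine similarity between all 2N embeddings."""
    return [[cosine_similarity(a, b) for b in z] for a in z]

# Hypothetical embeddings: two "cat" views followed by two "elephant" views
z = [[1.0, 0.1], [0.9, 0.2], [-0.2, 1.0], [0.1, 0.8]]
S = similarity_matrix(z)
print(S[0][1] > S[0][2])  # True: cat-cat is more similar than cat-elephant
```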

 

 

First, the augmented pairs in the batch are taken one at a time, and a softmax is applied to get the probability that the two images are similar. This softmax computes the probability that the second augmented cat image is the one most similar to the first cat image of the pair. Here, all the remaining images in the batch are treated as dissimilar images (negative pairs).

 

 

Then the negative log of the value above gives the loss for the pair. The loss is computed for the other image pairs as well, and also for the same pair with the positions of the two images swapped.
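Written out, the softmax and negative-log steps correspond to the NT-Xent loss as defined in the SimCLR paper. Here τ is the temperature hyperparameter, sim is cosine similarity, and the convention that views 2k-1 and 2k form a positive pair follows the paper's notation rather than anything stated in this post.

```latex
\ell_{i,j} = -\log
  \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}
       {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)},
\qquad
\mathrm{sim}(u, v) = \frac{u^{\top} v}{\lVert u \rVert \, \lVert v \rVert}

\mathcal{L} = \frac{1}{2N} \sum_{k=1}^{N}
  \left[ \ell_{2k-1,\,2k} + \ell_{2k,\,2k-1} \right]
```

The second line is the batch loss: every positive pair contributes twice, once in each order, and the result is averaged over all 2N views.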

 

Finally, the loss is computed over every pair in the batch of size N=2 and averaged. As the model trains on this loss, the representations of the encoder and the projection head improve, and similar images are placed closer together in feature space.
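Putting the steps together (softmax over cosine similarities, negative log, both orderings of each pair, then the average), a minimal pure-Python sketch of the loss could look like this. It assumes views 2k and 2k+1 in the list come from the same source image; the temperature `tau=0.5` and the toy 2-D embeddings are arbitrary choices for illustration, not values taken from the post.

```python
import math

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def pair_loss(z, i, j, tau):
    """-log softmax of sim(z_i, z_j) against every other view k != i."""
    numerator = math.exp(cosine_similarity(z[i], z[j]) / tau)
    denominator = sum(math.exp(cosine_similarity(z[i], z[k]) / tau)
                      for k in range(len(z)) if k != i)
    return -math.log(numerator / denominator)

def nt_xent(z, tau=0.5):
    """Average the pair loss over both orderings of each positive pair.

    Assumes views 2k and 2k+1 in z come from the same source image.
    """
    n_pairs = len(z) // 2
    total = 0.0
    for k in range(n_pairs):
        i, j = 2 * k, 2 * k + 1
        total += pair_loss(z, i, j, tau) + pair_loss(z, j, i, tau)
    return total / (2 * n_pairs)

# N = 2: hypothetical embeddings ordered as (cat, cat, elephant, elephant)
aligned = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]]
# Same vectors, but ordered so the "positive" pairs point in different directions
shuffled = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.1], [0.1, 1.0]]
print(nt_xent(aligned) < nt_xent(shuffled))  # True: aligned positives give a lower loss
```

The comparison at the end is the whole point of the loss: when the two views of each image already sit close together and far from the other images, the loss is small.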

 

Downstream Tasks (Transfer learning)

 

Once the SimCLR model has been trained with contrastive learning, it can be used for transfer learning. For this, the feature representations from the base encoder are used for downstream tasks such as classification and detection.
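One common way to use the features is to freeze the encoder and train only a small linear classifier on top of h (often called a linear probe). The sketch below assumes that setup; `frozen_encoder` is a hypothetical stand-in function rather than a trained network, and the data, learning rate, and epoch count are toy values.

```python
import math

def frozen_encoder(x):
    """Hypothetical stand-in for the pretrained base encoder; never updated here."""
    return [x[0] + x[1], x[0] - x[1]]

def train_linear_probe(features, labels, lr=0.5, epochs=200):
    """Fit a logistic-regression classifier on top of fixed features h."""
    w, b = [0.0] * len(features[0]), 0.0
    for _ in range(epochs):
        for h, y in zip(features, labels):
            logit = sum(wi * hi for wi, hi in zip(w, h)) + b
            p = 1.0 / (1.0 + math.exp(-logit))
            # Gradient step on the cross-entropy loss; only w and b change
            w = [wi - lr * (p - y) * hi for wi, hi in zip(w, h)]
            b -= lr * (p - y)
    return w, b

def predict(w, b, h):
    return 1 if sum(wi * hi for wi, hi in zip(w, h)) + b > 0 else 0

inputs = [[1.0, 0.2], [0.9, 0.1], [-1.0, -0.3], [-0.8, 0.0]]
labels = [1, 1, 0, 0]
features = [frozen_encoder(x) for x in inputs]  # h is computed once and kept fixed
w, b = train_linear_probe(features, labels)
print([predict(w, b, h) for h in features])  # [1, 1, 0, 0]
```

Note that the projection head does not appear here: the classifier sits on h from the base encoder, matching the description above.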

 

The figures below are from the paper and describe the SimCLR algorithm.

 

NT-Xent Loss

 

์•„๋ž˜๋Š” data augmentation์˜ ์˜ˆ์‹œ์™€ evaluation ๊ฒฐ๊ณผ ํ‘œ์ž…๋‹ˆ๋‹ค. augmentation์€ 2๊ฐœ๋ฅผ ์กฐํ•ฉํ•œ ๊ฒฐ๊ณผ๊ฐ€ 1๊ฐœ๋ฅผ ์“ฐ๋Š” ๊ฒฝ์šฐ๋ณด๋‹ค ์ข‹์•˜๊ณ  crop + color distortion์ด ์ œ์ผ ์ข‹์€ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค.

 
