
[Paper Review] Learning to Compare: Relation Network for Few-Shot Learning / meta-learning, few-shot learning

by 뭅즤 2021. 10. 17.
๋ฐ˜์‘ํ˜•

๋ณธ ๋…ผ๋ฌธ์€ CVPR2018์— ๊ฒŒ์žฌ๋œ few shot learning ์ด๋ผ๋Š” ์ฃผ์ œ์˜ ๋…ผ๋ฌธ์ž…๋‹ˆ๋‹ค. 

 

๋”ฅ๋Ÿฌ๋‹์—์„œ ๋ฐ์ดํ„ฐ์˜ ๊ฐœ์ˆ˜๋Š” ์„ฑ๋Šฅ๊ณผ ์ง๊ฒฐ๋˜์ง€๋งŒ, ํ˜„์‹ค์ ์ธ ํ…Œ์Šคํฌ์—์„œ ๋ฐ์ดํ„ฐ ๊ฐœ์ˆ˜๋Š” ๋Š˜ ๋ถ€์กฑํ•  ์ˆ˜ ๋ฐ–์— ์—†์Šต๋‹ˆ๋‹ค. 

Several approaches address this limited-data problem:

- At the data level, there is data augmentation.
- At the network level, there are unsupervised/semi-supervised learning, transfer learning, and meta-learning.

Few-shot learning is a methodology that uses meta-learning to train a network from only a small number of examples. Meta-learning research spans several families, including metric-based, model-based, optimization-based, and GCN-based approaches.

 

Few shot learning์€ ๋งค์šฐ ์ ์€ ๋ฐ์ดํ„ฐ ๊ฐœ์ˆ˜๋กœ ๋„คํŠธ์›Œํฌ๋ฅผ ํ•™์Šต์‹œํ‚ค๋Š” ๋ฐฉ๋ฒ•์ด๋ฉฐ 'N-way K-shot' ๋ฌธ์ œ๋ผ๊ณ  ๋ถ€๋ฆ…๋‹ˆ๋‹ค. N์€ class์˜ ๊ฐœ์ˆ˜์ด๋ฉฐ, K๋Š” class๋ณ„ ๋ฐ์ดํ„ฐ์˜ ์ˆ˜๋ฅผ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค. K๊ฐ€ ๋งŽ์„์ˆ˜๋ก class๋ณ„ ์ธ์Šคํ„ด์Šค์˜ ์ˆ˜๊ฐ€ ๋งŽ์•„์ง€๋ฏ€๋กœ ์„ฑ๋Šฅ์ด ์ข‹์•„์ง€๊ณ  N์ด ๋งŽ์•„์ง€๋ฉด ๋ถ„๋ฅ˜ํ•ด์•ผํ•  class์˜ ์ˆ˜๊ฐ€ ๋Š˜์–ด๋‚˜๊ธฐ ๋•Œ๋ฌธ์— ์„ฑ๋Šฅ์ด ๋–จ์–ด์ง‘๋‹ˆ๋‹ค. Few shot learning์€ K๊ฐ€ ๋งค์šฐ ์ž‘์€ ์ƒํ™ฉ์„ ๊ฐ€์ •ํ•˜๋ฉฐ ๋Œ€๋ถ€๋ถ„ ์—ฐ๊ตฌ์—์„œ๋Š” benchmark๋กœ 5-way 1-shot, 5-way 5-shot ๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

 

Few shot learning ๋ชจ๋ธ์ด ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ(ํ•™์Šตํ•˜์ง€ ์•Š์€ class)์—์„œ๋„ ์ž˜ ๋™์ž‘ํ•˜๋„๋ก episodic training ๋ฐฉ์‹์˜ meta-learning์„ ์‚ฌ์šฉํ•˜์—ฌ ์ผ๋ฐ˜ํ™” ์„ฑ๋Šฅ์„ ๊ทน๋Œ€ํ™”์‹œํ‚ต๋‹ˆ๋‹ค. 

*Meta-learning : ์ ์€ ์ˆ˜์˜ data๋กœ ๋ชจ๋ธ์„ ํ•™์Šต์‹œํ‚ค์œ„ํ•œ learn to learn ๋ฐฉ์‹์˜ ํ•™์Šต ๋ฐฉ๋ฒ•์„ ์ง€์นญ

 

Episode Learning

Episodic training์€ ๋Œ€์šฉ๋Ÿ‰ training dataset์—์„œ N-way K-shot์˜ support set 1๊ฐœ์™€ query set 1๊ฐœ๋กœ 1๊ฐœ์˜ ์—ํ”ผ์†Œ๋“œ๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค. ์—ฌ๋Ÿฌ๊ฐœ์˜ ์—ํ”ผ์†Œ๋“œ๋ฅผ ์ƒ์„ฑํ•˜์—ฌ ๋„คํŠธ์›Œํฌ๋ฅผ ํ•™์Šต์‹œํ‚ค๊ณ  test์‹œ์—๋Š” training dataset์— ์กด์žฌํ•˜์ง€ ์•Š์•˜๋˜ class๋กœ support set๊ณผ query set์„ ๋งŒ๋“ค์–ด ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.

 

1) Training data์—์„œ n-way k-shot ์˜ data ์ถ”์ถœํ•˜์—ฌ support set ๊ตฌ์„ฑ

 

2) Training data์—์„œ support set์— ํฌํ•จ๋œ class๋กœ query set ๊ตฌ์„ฑ 

3) Query set์„ ํ‰๊ฐ€ํ•˜๊ณ  ๊ณ„์‚ฐ๋œ loss๋กœ Network์„ update

4) 1~3๋ฒˆ์„ ๋ฐ˜๋ณตํ•˜๋ฉฐ training set์—๋Š” ์—†๋˜ class๋ฅผ ๊ฐ€์ง€๋Š” test dataset์œผ๋กœ support/query set์„ ๋งŒ๋“ค์–ด ์„ฑ๋Šฅ์„ ํ…Œ์ŠคํŠธํ•ฉ๋‹ˆ๋‹ค.

 

 

๋ฐ˜์‘ํ˜•

 

Metric Learning Approach

Metric-based ๋ฐฉ๋ฒ•์€ ์ด๋ฏธ์ง€๊ฐ€ feature space๋กœ embedding๋˜์—ˆ์„ ๋•Œ ๋™์ผํ•œ class๋ผ๋ฆฌ๋Š” feature๊ฐ€ ๋งค์šฐ ๊ฐ€๊นŒ์šด ๊ฑฐ๋ฆฌ์— ์žˆ๋„๋ก ํ•™์Šต์‹œํ‚ค๊ณ  ์„œ๋กœ ๋‹ค๋ฅธ class๋ผ๋ฆฌ๋Š” ๋จผ ์œ„์น˜์— ์žˆ๋„๋ก ํ•™์Šต์‹œํ‚ค๋Š” ๋ฐฉ๋ฒ•์ž…๋‹ˆ๋‹ค. feature space์—์„œ feature ๊ฐ„ ๊ฑฐ๋ฆฌ๋Š” euclidean metric, cosine metric ๋“ฑ์˜ ๋ฐฉ๋ฒ•์œผ๋กœ ๋ฏธ๋ถ„๊ฐ€๋Šฅํ•œ distance๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. Ground truth๋Š” 2๊ฐœ ์ธ์Šคํ„ด์Šค๊ฐ€ ์„œ๋กœ ๋™์ผํ•œ class์ธ ๊ฒฝ์šฐ 1๋กœ, ์„œ๋กœ ๋‹ค๋ฅธ class์ธ ๊ฒฝ์šฐ 0์œผ๋กœ ์ง€์ •ํ•˜์—ฌ similarity๊ฐ€ ๋†’์€ data ์Œ์ผ์ˆ˜๋ก output์ด 0์—์„œ 1์— ๊ฐ€๊นŒ์›Œ์ง€๋„๋ก ํ•™์Šต์‹œํ‚ต๋‹ˆ๋‹ค.

 

The simplest approach is the Siamese network, shown in the figure below. Two images are passed through one shared network, and the distance is measured between the two resulting embedded features.

์ฆ‰, ์ด๋Ÿฌํ•œ ๋ฐฉ์‹์œผ๋กœ ํ•™์Šต์‹œํ‚ค๋ฉด ์ผ๋ฐ˜์ ์ธ classification ๋ฌธ์ œ์™€๋Š” ๋‹ฌ๋ฆฌ ์ด๋ฏธ class ์ •๋ณด๋ฅผ ์•Œ๊ณ  ์žˆ๋Š” ์ด๋ฏธ์ง€ 1์žฅ๊ณผ ํ…Œ์ŠคํŠธ ์ด๋ฏธ์ง€ 1์žฅ์„ ๋„ฃ์–ด 2๊ฐœ์˜ ์ด๋ฏธ์ง€๊ฐ€ ๋™์ผํ•œ ํด๋ž˜์Šค์ธ์ง€ ๋™์ผํ•˜์ง€ ์•Š์€ ํด๋ž˜์Šค์ธ์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

However, the network still never trains on the classes it will be tested on. Performance therefore degrades considerably when the domain gap between the training and test data is large.

 

 

Relation Network

previous work์—์„œ๋Š” euclidean distance, cosine distance๋ฅผ ์ด์šฉํ•˜์—ฌ ํ•™์Šต์‹œ์ผฐ๋Š”๋ฐ, ๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” Relation Module(RM) ์ด๋ผ๋Š” fc layer๋ฅผ ํ†ตํ•ด 2์žฅ์˜ ์ด๋ฏธ์ง€๋กœ ๋ถ€ํ„ฐ ์ถ”์ถœ๋œ 2๊ฐœ์˜ feature๋ฅผ concat ์‹œํ‚ค๊ณ  RM์„ ๊ฑฐ์ณ relation score๋ฅผ ์ƒ์„ฑํ•˜๊ฒŒ ๋˜๊ณ  ์ด๋Š” ์ด๋ฏธ์ง€ 2์žฅ์˜ similarity๋ฅผ ๋‚˜ํƒ€๋‚ด๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.

์ฆ‰, ์ด์ „ ์—ฐ๊ตฌ์—์„œ๋Š” ์ •ํ•ด์ง„ ์—ฐ์‚ฐ(euclidean metric, cosine metric)์œผ๋กœ 2๊ฐœ feature ๊ฐ„์˜ ๊ฑฐ๋ฆฌ๋ฅผ ๊ตฌํ–ˆ๋‹ค๋ฉด ๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” feature 2๊ฐœ๋ฅผ ์—ฐ๊ฒฐํ•˜๊ณ  fc layer ๋กœ similarity ๋ฅผ ํ•™์Šตํ•˜์—ฌ relation score๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค. ํ•™์Šต๊ฐ€๋Šฅํ•œ ๋ฐฉ์‹์œผ๋กœ similarity๋ฅผ ๊ตฌํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์ด์ „ ์—ฐ๊ตฌ์— ๋น„ํ•ด ์„ฑ๋Šฅ์ด ์ข‹๊ฒŒ ๋‚˜์˜ค๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

 

 

Few-shot classification accuracies on miniImagenet

๋ฐ˜์‘ํ˜•