๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ’ป Programming/AI & ML

[HuggingFace] Swin Transformer ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ชจ๋ธ ํ•™์Šต ํŠœํ† ๋ฆฌ์–ผ

by ๋ญ…์ฆค 2023. 1. 11.
๋ฐ˜์‘ํ˜•

์ตœ๊ทผ์— ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•  ์ผ์ด ์ƒ๊ฒจ ๋”ฅ๋Ÿฌ๋‹ ํ”„๋ ˆ์ž„์›Œํฌ์ธ HuggingFace ๋ฅผ ์‚ฌ์šฉํ•ด๋ดค๋‹ค.

 

HuggingFace์˜ transformers๋Š” ๋‹ค์–‘ํ•œ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์„ ์ œ๊ณตํ•˜๊ณ  ์ž์ฒด ํ•™์Šต/ํ‰๊ฐ€ API๋ฅผ ์ œ๊ณตํ•œ๋‹ค. ๋˜ํ•œ ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ชจ๋ธ์€ Pytorch, Tensorflow ํ•™์Šต/ํ‰๊ฐ€ ์ฝ”๋“œ์—๋„ ๊ทธ๋Œ€๋กœ ์ ์šฉํ•  ์ˆ˜ ์žˆ์„๋งŒํผ ํ˜ธํ™˜์„ฑ์ด ์ข‹๋‹ค๊ณ  ์•Œ๋ ค์ ธ ์žˆ๋‹ค. 

๊ทธ๋ž˜์„œ pytorch ํ•™์Šต ์ฝ”๋“œ์— ๋ชจ๋ธ๋งŒ ํ—ˆ๊น…ํŽ˜์ด์Šค์˜ ํŠธ๋žœ์Šคํฌ๋จธ๋กœ ๋ฐ”๊ฟ”์„œ ํ•™์Šต์‹œํ‚ค๋ฉด ๋˜๋‹ˆ๊นŒ ๊ฐ„๋‹จํ•˜๊ฒ ๊ตฌ๋‚˜ ์ƒ๊ฐํ–ˆ์ง€๋งŒ... ์ƒ๊ฐ๋ณด๋‹ค ๊ณ ๋ คํ•ด์•ผํ•  ์ ๋“ค์ด ์žˆ์–ด์„œ ํŠœํ† ๋ฆฌ์–ผ๋กœ ์ •๋ฆฌํ•ด ๋ณธ๋‹ค.

 

 

 

HuggingFace Vision Model ์‚ฌ์šฉ ๋ฐฉ๋ฒ•

 

๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„

Vision Transformer(ViT) ๊ธฐ๋ฐ˜์˜ ์ด๋ฏธ์ง€ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ๋“ค์€ ์ž…๋ ฅ ์ด๋ฏธ์ง€๋ฅผ ์œ„์™€ ๊ฐ™์ด ํŒจ์น˜๋กœ ์ž˜๋ผ์„œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ๊ตฌ์„ฑํ•˜๊ฒŒ ๋œ๋‹ค. ์œ„ ๊ณผ์ •์ด ๋ชจ๋ธ ์ž์ฒด์— ํฌํ•จ๋œ pytorch ๊ธฐ๋ฐ˜์˜ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ๋“ค๋„ ์žˆ์ง€๋งŒ, ํ—ˆ๊น…ํŽ˜์ด์Šค์—์„œ๋Š” ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜ํ™” ์‹œํ‚ค๋Š” ํด๋ž˜์Šค์™€ ๋ชจ๋ธ ํด๋ž˜์Šค๊ฐ€ ๋ถ„๋ฆฌ๋˜์–ด ์žˆ๋‹ค.

 

๋•Œ๋ฌธ์— ๋ฐ์ดํ„ฐ์…‹ ์ž์ฒด๋ฅผ ๋ฏธ๋ฆฌ ํŒจ์น˜ํ™”์‹œ์ผœ๋‘๊ณ  ์‚ฌ์šฉํ•˜๊ฑฐ๋‚˜ ๋ฐ์ดํ„ฐ๋กœ๋”๋กœ ๋ฐฐ์น˜๋‹จ์œ„๋กœ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์˜ฌ ๋•Œ ํŒจ์น˜ํ™” ์‹œํ‚ฌ ์ˆ˜๋„ ์žˆ๋‹ค.

 

 

ํ•™์Šต ๋ฐ ํ‰๊ฐ€

 

ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ชจ๋ธ ํ•™์Šต์‹œ์—๋Š” ํ—ˆ๊น…ํŽ˜์ด์Šค ์ž์ฒด ํ•™์Šต API์ธ Trainer ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ๊ณ , ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ชจ๋ธ์„ pytorch ํ•™์Šต ์ฝ”๋“œ๋กœ ํ•™์Šต์‹œํ‚ฌ ์ˆ˜๋„ ์žˆ๋‹ค. ๋‹ค๋ฅธ ๋ชจ๋ธ๋“ค๊ณผ์˜ ๋น„๊ต๋ฅผ ์œ„ํ•ด ๊ธฐ์กด ํ•™์Šต ์ฝ”๋“œ์— ๋ชจ๋ธ์„ ์ถ”๊ฐ€ํ•˜๋Š” ๊ฒฝ์šฐ๊ฐ€ ์•„๋‹ˆ๋ผ๋ฉด, ๊ฐœ์ธ์ ์œผ๋กœ ํ—ˆ๊น…ํŽ˜์ด์Šค ์ž์ฒด API๋ฅผ ์“ฐ๋Š”๊ฒŒ ํŽธํ•  ๊ฒƒ ๊ฐ™๋‹ค. ๋ชจ๋ธ configuration ์„ค์ •์ด๋‚˜ pre-train๋œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ์ €์žฅํ•˜๋Š” ์ฝ”๋“œ, metric ๊ด€๋ จ ํ•จ์ˆ˜๋“ค์ด ๊ต‰์žฅํžˆ ์‚ฌ์šฉํ•˜๊ธฐ ํŽธํ•˜๊ฒŒ ๋˜์–ด์žˆ๋‹ค.

 

 

HuggingFace Swin Transformer ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ํŠœํ† ๋ฆฌ์–ผ

 

HuggingFace์˜ ๊ณต์‹ ๋ ˆํผ์ง€ํ† ๋ฆฌ์— ๋‹ค์–‘ํ•œ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์˜ ํŠœํ† ๋ฆฌ์–ผ์„ ์ œ๊ณตํ•˜๊ณ  ์žˆ๋Š”๋ฐ, ๊ธฐ๋ณธ์ ์ธ ํ•™์Šต/ํ‰๊ฐ€ ์ฝ”๋“œ ํ•ด์„์ด ํ•„์š”ํ•˜์‹  ๋ถ„์€ ์•„๋ž˜ ํŠœํ† ๋ฆฌ์–ผ์„ ์ฐธ๊ณ ํ•˜๋ฉด ์ข‹์„ ๊ฒƒ ๊ฐ™๋‹ค.

 

- HuggingFace ๊ณต์‹ ํŠœํ† ๋ฆฌ์–ผ ๋ ˆํผ์ง€ํ† ๋ฆฌ : https://github.com/huggingface/transformers/tree/main/examples/pytorch

- ํŠœํ† ๋ฆฌ์–ผ ์ฝ”๋“œ : https://github.com/ldj7672/Deep-Learning-Model-Tutorials/blob/main/HuggingFace/HuggingFace_SwinT_image_classification.ipynb

 

1. ํŒจํ‚ค์ง€ ๋‹ค์šด๋กœ๋“œ & ํ™˜๊ฒฝ ์„ธํŒ…

 

  • ์šฐ์„  HuggingFace์˜ transformers์™€ datasets๋ฅผ ์„ค์น˜

 

 

2. ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„

 

  • ํ—ˆ๊น…ํŽ˜์ด์Šค ํ™ˆํŽ˜์ด์ง€์— ๊ณต๊ฐœ๋˜์–ด ์žˆ๋Š” ๋ฐ์ดํ„ฐ์…‹์˜ ๊ฒฝ์šฐ ๋ฐ์ดํ„ฐ์…‹ ์ด๋ฆ„์„ ๋ฌธ์ž์—ด๋กœ ๋„˜๊ฒจ์ฃผ๋ฉด ๋‹ค์šด ๊ฐ€๋Šฅ (ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ๊ณ ์–‘์ด์™€ ๊ฐœ ๋ถ„๋ฅ˜ ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉ)
  • ๋ฐ์ดํ„ฐ์…‹์„ ๋ช‡๊ฐœ ๋‹ค์šด๋ฐ›์•„ ๋ณด๋‹ˆ key ๊ฐ’์ด ์ผ์ •ํ•˜์ง€๊ฐ€ ์•Š์€๋ฐ ์‚ฌ์šฉ ํŽธ์˜๋ฅผ ์œ„ํ•ด ๋ณธ์ธ์˜ ๋ฃฐ๋Œ€๋กœ ์ง€์ •ํ•ด์ฃผ๊ณ , ์ธ๋ฑ์Šค๋ณ„ ๋ ˆ์ด๋ธ”์„ ์ง€์ •
    • ๋‚˜์ค‘์— ๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค ์‹œ์— 0, 1, 2 ๋“ฑ์˜ ์ˆซ์ž๊ฐ€ ์•„๋‹Œ '๊ฐœ', '๊ณ ์–‘์ด' ๊ฐ™์€ ๋ ˆ์ด๋ธ”์„ ์–ป์„ ์ˆ˜ ์žˆ์Œ

 

3. HuggingFace ๋ชจ๋ธ ์„ธํŒ…

  • ํ—ˆ๊น…ํŽ˜์ด์Šค ํ™ˆํŽ˜์ด์ง€์—๋Š” ๋‹ค์–‘ํ•œ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์˜ pre-trained ๋ชจ๋ธ์ด ๊ณต๊ฐœ๋˜์–ด ์žˆ๊ณ , ์—ญ์‹œ๋‚˜ ๋‹ค์šด๋ฐ›์•„์„œ ๋ฐ”๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Œ
  • ์•ž์„œ ์„ค๋ช…ํ–ˆ๋‹ค์‹œํ”ผ ํŠธ๋žœ์Šคํฌ๋จธ ๊ณ„์—ด ๋ชจ๋ธ์€ CNN ๊ณ„์—ด ๋ชจ๋ธ๊ณผ ๋‹ฌ๋ฆฌ ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜๋กœ ์ž˜๋ผ์„œ ๋ชจ๋ธ์— ๋„ฃ์–ด์ค˜์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— preprocessor์™€ transformer model ๋‘ ๊ฐ€์ง€ ๋‹ค ํ•„์š”
  • ๋ชจ๋ธ configuration์„ ์ˆ˜์ •ํ•˜๋ฉด ๋ชจ๋ธ์˜ ์—ฌ๋Ÿฌ๊ฐ€์ง€ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ์Œ

 

 

# ๋ฐ์ดํ„ฐ์…‹ transformation ๋ฐ preprocess

  • preprocessor์˜ ์ถœ๋ ฅ ์‚ฌ์ด์ฆˆ๋Š” ๊ณ ์ •์ด๊ธฐ ๋•Œ๋ฌธ์— ์‚ฌ์ด์ฆˆ์— ๋งž๊ฒŒ ๋ฐ์ดํ„ฐ์…‹ transformation์„ ์„ค์ •

 

 

# ๋ฐ์ดํ„ฐ์…‹ split ๋ฐ transform ์ ์šฉ

  • ์ด์ œ ๋ฐ์ดํ„ฐ์…‹์— transform์„ ์ ์šฉ์‹œํ‚ค๊ณ , ๋ฐ์ดํ„ฐ์…‹์„ ์ž˜ ์„ ์–ธํ–ˆ๋Š”์ง€ ์ด๋ฏธ์ง€์™€ ๋ ˆ์ด๋ธ”์„ ํ™•์ธ

 

 

4. HuggingFace ํ•™์Šต API 'Trainer' ์„ธํŒ…

  • pytorch ํ•™์Šต ์ฝ”๋“œ์™€ ๊ฐ€์žฅ ๋งŽ์ด ๋‹ค๋ฅธ ๋ถ€๋ถ„์ธ๋ฐ, TrainingArguemntes๋กœ ๊ฐ์ข… ํ•™์Šต ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์„ค์ •ํ•˜๊ณ  ํ—ˆ๊น…ํŽ˜์ด์Šค์˜ ์ž์ฒด ํ•™์Šต API ์ธ Trainer๋ฅผ ์„ ์–ธ
  • tokenizer์—๋Š” preprocessor๋ฅผ ์ง€์ •ํ•ด์ค˜์•ผ ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜๋กœ ์ž˜๋ผ์„œ ํŠธ๋žœ์Šคํฌ๋จธ ๋ชจ๋ธ์— ๋„ฃ์–ด์ฃผ๊ฒŒ ๋จ

 

*์–ธ์–ด ๋ชจ๋ธ์ด๋‚˜ ๋ฉ€ํ‹ฐ ๋ชจ๋‹ฌ ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ tokenizer๋Š” ๋ฌธ์ž์—ด์„ ์ž„๋ฒ ๋”ฉํ•˜๋Š” ์—ญํ• ์„ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.

 

5. ๋ชจ๋ธ ํ•™์Šต

Trainer ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ์ฝ”๋“œ๊ฐ€ ์ •๋ง ๊ฐ„๊ฒฐํ•ด์ง. ์œ„ ์ฝ”๋“œ๋งŒ ์‹คํ–‰ํ•˜๋ฉด ํ•™์Šต์ด ์™„๋ฃŒ๋˜๊ณ  ์ง€์ •ํ•œ metric์— ๋”ฐ๋ผ ์•„๋ž˜์™€ ๊ฐ™์ด ํ•™์Šต ๊ฒฐ๊ณผ๋„ ์ถœ๋ ฅ๋œ๋‹ค.

 

6. ๋ชจ๋ธ ํ‰๊ฐ€

๋ชจ๋ธ evaluation ๋˜ํ•œ Trainer๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๊ต‰์žฅํžˆ ๊ฐ„ํŽธํ•˜๊ฒŒ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค.

 

 

7. ๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค

  • ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ชจ๋ธ์˜ ๊ฒฝ์šฐ ๋ชจ๋ธ์„ ์ €์žฅํ•˜๋ฉด ๋ชจ๋ธ ์ •๋ณด๊ฐ€ ๋‹ด๊ธด json ํŒŒ์ผ๊ณผ ๋ชจ๋ธ ์›จ์ดํŠธ๊ฐ€ ์ €์žฅ๋œ binํŒŒ์ผ์ด ์ €์žฅ
  • preprocessor ์˜ ๊ฒฝ์šฐ pre-train๋œ ๊ฒƒ์„ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉํ•˜๊ณ  Swin Transformer์˜ ๊ฒฝ์šฐ์—๋Š” ๋ฐฉ๊ธˆ ํ•™์Šตํ•œ ๋ชจ๋ธ์ด ์ €์žฅ๋œ ๋””๋ ‰ํ† ๋ฆฌ๋ฅผ from_pretrained()์— ๋„˜๊ฒจ์ค˜์„œ ํ•™์Šต๋œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ด

 

๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค ์‹œ์—๋„ ์—ญ์‹œ๋‚˜ ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜ํ™”ํ•˜๊ธฐ ์œ„ํ•ด preprocessor์— ๋จผ์ € ํ†ต๊ณผ์‹œ์ผœ์„œ ํŠธ๋žœ์Šคํฌ๋จธ ์ž…๋ ฅ์„ ๋งŒ๋“ค๊ณ  ๋ชจ๋ธ์— ์ธํผ๋Ÿฐ์Šค ์‹œ์ผœ์„œ ์ตœ์ข… ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•œ๋‹ค.

 

... ๊ณ ์–‘์ด์ธ๋ฐ 2 ์—ํญ๋งŒ ํ•™์Šต์‹œ์ผœ์„œ ๊ทธ๋Ÿฐ์ง€ ๊ฐ•์•„์ง€๋ผ๊ณ  ์ธ์‹ํ•œ๋‹ค.. ใ…Žใ…Ž

 

๋ฐ˜์‘ํ˜•