
Understanding PyTorch FSDP (Fully Sharded Data Parallel)!


1. What is FSDP (Fully Sharded Data Parallel)?

1.1 The FSDP concept

FSDP is an advanced distributed training technique provided by PyTorch. Unlike DDP, which replicates every model parameter on each GPU, FSDP splits the model's parameters into shards and stores one shard per GPU. This drastically reduces GPU memory usage.

 

FSDP works by having each GPU store only a shard of the model rather than the whole thing. For a forward or backward computation, the GPUs exchange the parameters they need (all-gather), run the computation, and then drop the gathered parameters so that only the shards remain; during backward, the resulting gradients are likewise split back across GPUs with reduce-scatter.

 

1.2 DDP vs. FSDP

Aspect              | DDP                                            | FSDP
Model parameters    | Each GPU holds a full replica of the model     | Parameters are split into shards, one per GPU
GPU memory          | Whole model stored per GPU → high memory use   | Shards spread across GPUs → memory savings
Data processing     | Each GPU processes a different mini-batch      | Each GPU processes a different mini-batch (same as DDP)
Inter-GPU traffic   | Gradient synchronization (all-reduce) only     | Adds parameter all-gather and gradient reduce-scatter

In other words, FSDP combines the memory savings of model parallelism with the fast parallel processing of data parallelism.
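At the API level the two look almost identical; here is a minimal sketch of the difference in wrapping (assuming a process group is already initialized and the current CUDA device is set, as in section 4.1 below):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Linear(1024, 1024).cuda()   # placeholder model for illustration

# DDP: every rank keeps a full replica; only gradients are all-reduced.
# model = DDP(model)

# FSDP: parameters are flattened and sharded across ranks; they are
# all-gathered only while the wrapped unit is actually being computed.
model = FSDP(model)

The training loop itself stays the same in both cases; only the wrapper changes.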

 

2. Understanding FSDP

GPU0 : shard 0
GPU1 : shard 1
GPU2 : shard 2
GPU3 : shard 3

So the model parameters are split across multiple GPUs as shown above. How, then, does training actually happen?

 

2.1 The FSDP forward pass

  1. Before the forward pass starts,
  2. the layer (parameters) it needs is all-gathered across the GPUs
    → each GPU temporarily collects all the parameters required for the forward computation
  3. and once the forward computation is finished,
  4. the gathered parameters are dropped again so that only the shards remain
    → freeing the GPU memory.

In other words, an individual GPU does not hold the whole model all the time; the GPUs briefly gather the parameters they need for the forward pass, compute, and then go back to holding only their shards.

 

Put a bit more formally: the reason a GPU can run the forward pass without permanently holding the whole model is that parameters are gathered over the network (all-gather) only when they are actually needed.
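As a toy illustration of that idea, here is a single-process simulation (not real FSDP): gather_full stands in for the all-gather collective, and the shard lists stand in for what each rank stores persistently.

import torch

# Toy simulation of the FSDP forward idea: each rank keeps only its shard of
# every layer, rebuilds the full weight just before use, and drops it again
# right after. No real communication happens here.
world_size = 4
layers = [torch.randn(8, 8) for _ in range(3)]                       # full weights
shards = [list(torch.chunk(w, world_size, dim=0)) for w in layers]   # what each rank stores

def gather_full(layer_shards):
    return torch.cat(layer_shards, dim=0)                            # stand-in for all-gather

x = torch.randn(2, 8)
for layer_shards in shards:
    full_w = gather_full(layer_shards)    # gather this layer's parameters only
    x = torch.tanh(x @ full_w.t())        # run this layer's forward
    del full_w                            # keep only the shards in "memory"
print(x.shape)                            # torch.Size([2, 8])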

 

2.2 FSDP forward / backward communication flow

[ GPU0 ]   [ GPU1 ]   [ GPU2 ]   [ GPU3 ]
   |          |          |          |
   |          |          |          |
   |------ all-gather parameters -->|
   |  (GPUs exchange their model    |
   |   shards so each GPU briefly   |
   |   holds the full parameters    |
   |   needed for the layer)        |
   |          |          |          |
   |   forward & backward compute   |
   |          |          |          |
   |<----- reduce-scatter ----------|
   | (gradients are summed and      |
   |  split back into shards; the   |
   |  gathered parameters are freed)|
   |          |          |          |
 

 

2.3 But isn't the whole model on each GPU anyway, at least for a moment??

 

While studying this, a question comes up. If FSDP ends up with each GPU effectively holding the full model during forward/backward anyway, why bother sharding at all? Each GPU can apparently hold the whole model, and all we seem to get is extra inter-GPU communication. What exactly is efficient about this??

 

✅ First, let's look at what actually lives in GPU memory

Component   | Created when            | Why it is needed                            | GPU memory impact      | Notes
weights     | at model initialization | forward computation and parameter updates   | relatively small       | the model parameters (W, b); always resident on the GPU
gradients   | during backward         | to update the weights                       | similar to the weights | the computed ∂Loss/∂W; consumed by the optimizer after backward
activations | during forward          | needed for chain-rule gradient computation in backward | largest | each layer's outputs; grow rapidly with batch size and sequence length

 

In short: the model parameters are the weights, the values produced during the forward pass are the activations, and those activations are used during backward to compute the gradients.
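As a rough back-of-the-envelope illustration (the 7-billion-parameter count and the fp32/Adam assumptions are mine, purely for scale), the parameter-related memory alone already exceeds a single GPU, before any activations are counted:

# Back-of-the-envelope memory estimate for a hypothetical 7B-parameter model
# trained in fp32 with Adam (assumptions for illustration only).
num_params = 7e9
bytes_fp32 = 4

weights    = num_params * bytes_fp32       # ~28 GB
gradients  = num_params * bytes_fp32       # ~28 GB
adam_state = num_params * bytes_fp32 * 2   # exp_avg + exp_avg_sq, ~56 GB

total_gb = (weights + gradients + adam_state) / 1e9
print(f"params + grads + optimizer state ≈ {total_gb:.0f} GB")   # ≈ 112 GB
# Activations come on top of this and scale with batch size and sequence length.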

 

 

✅ The actual FSDP computation flow

layer1 : all-gather only the shards it needs → forward → free the gathered parameters (keep the local shard)
layer2 : all-gather only the shards it needs → forward → free the gathered parameters (keep the local shard)
...
(during backward the same per-layer gathering happens again, and the resulting gradients are reduce-scattered)

 

  • To state the conclusion first: FSDP does not all-gather the entire model onto each GPU at once during forward/backward.
  • FSDP splits the model into units of layers or parameter buckets and manages each unit separately.
  • So when forward computes one layer,
    • only that layer's parameter shards are all-gathered across the GPUs to reconstruct its full parameters,
    • the forward computation for that layer runs,
    • and when it finishes, the gathered parameters are freed so that only the local shard stays in memory (the corresponding gradients are reduce-scattered during backward).
  • Backward proceeds the same way, layer by layer in reverse order.

So even during forward/backward there is no moment at which a GPU holds the entire model in memory at once. Even when a model is too large to fit on a single GPU in one piece, FSDP keeps the computation going layer-by-layer or bucket-by-bucket, which is what makes training possible.

With the legacy bucket-style wrapping, FSDP all-gathers parameters grouped into chunks, and if the chunk size is large it effectively pulls a large share of the parameters onto the GPU at the start of the forward pass. The memory peak then grows and the pattern can look similar to DDP, which is why people often simplify this to "during forward/backward the GPU is holding the whole model anyway."
Recent PyTorch FSDP (v1.12 and later), however, uses smaller buckets and performs all-gathers dynamically per layer, so in practice only the shards needed for the current layer are brought onto the GPU and memory usage stays distributed.
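How finely the model is split into such units is controlled by the wrapping policy passed to FSDP. A minimal sketch using PyTorch's size-based policy (the 1M-parameter threshold is an arbitrary example value, and model is assumed to be an already constructed nn.Module):

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule with more than ~1M parameters as its own FSDP unit,
# so parameters are all-gathered (and freed) one unit at a time instead of
# the whole model at once. The threshold is illustrative.
wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)

fsdp_model = FSDP(model, auto_wrap_policy=wrap_policy)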

 

 

 
Phase                              | DDP                                                      | FSDP
During forward/backward            | GPU must hold all parameters + activations + gradients   | only the current layer's shards are all-gathered → computed → freed (gradients reduce-scattered); the full model is never resident at once
After forward/backward (steady)    | GPU keeps the full parameters                            | GPU keeps only its own shard; the optimizer step also runs on shards

 

✅ FSDP forward flow

[Layer1]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3

    ↓ all-gather (Layer1)
GPU0: full weight1
GPU1: full weight1
GPU2: full weight1
GPU3: full weight1

    ↓ forward compute (Layer1)
    ↓ free the gathered weight (keep only the shard)
GPU0: shard0 (Layer1)
GPU1: shard1 (Layer1)
GPU2: shard2 (Layer1)
GPU3: shard3 (Layer1)


[Layer2]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3

    ↓ all-gather (Layer2)
GPU0: full weight2
GPU1: full weight2
GPU2: full weight2
GPU3: full weight2

    ↓ forward compute (Layer2)
    ↓ free the gathered weight (keep only the shard)
GPU0: shard0 (Layer2)
GPU1: shard1 (Layer2)
GPU2: shard2 (Layer2)
GPU3: shard3 (Layer2)


[Layer3]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3

    ↓ all-gather (Layer3)
GPU0: full weight3
GPU1: full weight3
GPU2: full weight3
GPU3: full weight3

    ↓ forward compute (Layer3)
    ↓ free the gathered weight (keep only the shard)
GPU0: shard0 (Layer3)
GPU1: shard1 (Layer3)
GPU2: shard2 (Layer3)
GPU3: shard3 (Layer3)

...

 

3. Prerequisites for using FSDP

3.1 Environment setup

  • PyTorch 1.12 or later is required
  • CUDA and NCCL must be installed (for fast inter-GPU communication)
  • The same code, data, and libraries must be installed on every GPU server

3.2 Preparing the multi-GPU environment

  • Prepare a single-machine or multi-machine (multi-node) environment
  • For multi-machine setups, MASTER_ADDR, MASTER_PORT, etc. must be configured (see the snippet below for how a script picks these up)
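For reference, when a launcher such as torchrun starts one process per GPU, it passes the rendezvous information (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, LOCAL_RANK) through environment variables; here is a minimal sketch of how a training script typically consumes them (the exact setup can differ per cluster):

import os
import torch
import torch.distributed as dist

# torchrun exports MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE and LOCAL_RANK
# for every process it spawns, so init_process_group can read them from the env.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)        # pin this process to one GPU
dist.init_process_group(backend="nccl")

print(f"rank {dist.get_rank()} / {dist.get_world_size()} on GPU {local_rank}")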

 

4. How to use PyTorch FSDP

4.1 Basic code structure

import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler

# Initialize the distributed environment (one process per GPU, e.g. via torchrun)
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)  # without this, every process would target GPU 0

# Create the model and wrap it with FSDP (MyModel / MyDataset are user-defined placeholders)
model = MyModel().cuda()
fsdp_model = FSDP(model)

# Data loader setup: DistributedSampler gives each rank a different data split
dataset = MyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Create the optimizer AFTER wrapping, so it sees the sharded parameters
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-4)

# Training loop
for epoch in range(10):
    sampler.set_epoch(epoch)  # reshuffle differently every epoch
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs = fsdp_model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Tear down the distributed environment
dist.destroy_process_group()
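The FSDP constructor also takes options that control how sharding and precision behave; here is a hedged sketch of two commonly used ones (the dtype choices are illustrative, not a recommendation):

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
)

# FULL_SHARD shards parameters, gradients and optimizer state (the default);
# MixedPrecision computes and communicates in bf16 while the sharded master
# weights stay in fp32.
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
    ),
)

These options can be combined with the auto_wrap_policy shown in section 2.3.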

4.2 Running FSDP in a multi-machine environment

In a multi-machine environment, the job can be launched with torchrun or a job scheduler such as Slurm.

 

torchrun example

torchrun --nnodes=2 --nproc_per_node=4 \
         --rdzv_backend=c10d --rdzv_endpoint=<master-host>:29500 \
         train.py

(Run the same command on every node; <master-host> is the address of one of the participating machines.)

 

Slurm script example

#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1        # one torchrun launcher per node
#SBATCH --gres=gpu:4

# Use the first allocated node as the rendezvous host for all ranks.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)

srun torchrun --nnodes=$SLURM_JOB_NUM_NODES \
              --nproc_per_node=4 \
              --rdzv_backend=c10d \
              --rdzv_endpoint=$MASTER_ADDR:29500 \
              train.py

 

Caveats when using FSDP

  • FSDP greatly reduces GPU memory usage, but inter-GPU communication (all-gather, reduce-scatter) increases.
  • When saving a checkpoint, the shards that FSDP has split up need to be consolidated, so the checkpointing strategy requires care (see the sketch below).
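One common pattern is to gather a full, unsharded state dict onto rank 0 before saving; here is a minimal sketch using the FSDP state-dict API (the checkpoint file name is arbitrary):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

# Gather the full, unsharded state dict on rank 0 (offloaded to CPU so the
# consolidated copy does not have to fit in GPU memory), then save it there.
save_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_cfg):
    full_state = fsdp_model.state_dict()

if dist.get_rank() == 0:
    torch.save(full_state, "checkpoint.pt")

PyTorch also offers a sharded state-dict mode that avoids materializing the full model on one rank; which one to use depends on model size and tooling.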

FSDP is a powerful PyTorch feature for using GPU memory efficiently when training very large models. It is especially effective when the model is too large for a single GPU's memory, and it is a modern distributed training method that combines the strengths of data parallelism with the memory management benefits of model parallelism.
