PyTorch FSDP (Fully Sharded Data Parallel) ์๋ฒฝ ์ดํดํ๊ธฐ!
1. FSDP(Fully Sharded Data Parallel)์ด๋?
1.1 FSDP ๊ฐ๋
FSDP๋ PyTorch์์ ์ ๊ณตํ๋ ๊ณ ๊ธ ๋ถ์ฐ ํ์ต ๊ธฐ๋ฒ์ผ๋ก, ๋ชจ๋ธ์ ๋ชจ๋ ํ๋ผ๋ฏธํฐ๋ฅผ GPU๋ง๋ค ๋ณต์ ํ๋ ๊ธฐ์กด DDP ๋ฐฉ์๊ณผ ๋ฌ๋ฆฌ, ๋ชจ๋ธ์ ํ๋ผ๋ฏธํฐ๋ฅผ GPU๋ผ๋ฆฌ shard(์กฐ๊ฐ) ๋จ์๋ก ๋๋์ด ์ ์ฅํ๋ ๋ฐฉ์์ด๋ค. ์ด๋ฅผ ํตํด GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ํญ ์ ์ฝํ ์ ์๋ค.
FSDP๋ GPU๋ง๋ค ๋ชจ๋ธ ์ ์ฒด๊ฐ ์๋ ์ผ๋ถ shard๋ง ์ ์ฅํ๊ณ , forward ๋ฐ backward ์ฐ์ฐ ์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ GPU ๊ฐ์ ์๋ก ๊ตํ(all-gather)ํ์ฌ ์ฐ์ฐ์ ์ํํ ํ ๋ค์ shard๋ก ๋ถ์ฐ ์ ์ฅ(reduce-scatter)ํ๋ ๋ฐฉ์์ผ๋ก ๋์ํ๋ค.
1.2 DDP vs FSDP ์ฐจ์ด
๊ตฌ๋ถ | DDP | FSDP |
๋ชจ๋ธ ํ๋ผ๋ฏธํฐ | ๊ฐ GPU๊ฐ ์ ์ฒด ๋ชจ๋ธ ๋ณต์ | ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ shard ๋จ์๋ก ๋๋ ์ ์ ์ฅ |
GPU ๋ฉ๋ชจ๋ฆฌ | GPU๋ง๋ค ๋ชจ๋ธ ์ ์ฒด๋ฅผ ์ ์ฅํด ๋ฉ๋ชจ๋ฆฌ ๋ง์ด ์ฌ์ฉ | GPU๋ผ๋ฆฌ shard ๋จ์๋ก ๋๋์ด ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ |
๋ฐ์ดํฐ ์ฒ๋ฆฌ | ๊ฐ GPU๊ฐ ๋ค๋ฅธ ๋ฐ์ดํฐ๋ฅผ ๋ณ๋ ฌ ์ฒ๋ฆฌ | ๊ฐ GPU๊ฐ ๋ค๋ฅธ ๋ฐ์ดํฐ๋ฅผ ๋ณ๋ ฌ ์ฒ๋ฆฌ (DDP์ ๋์ผ) |
GPU ๊ฐ ํต์ | gradient ๋๊ธฐํ(all-reduce)๋ง ์ํ | ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ๊ตํ(all-gather/reduce-scatter) ์ถ๊ฐ |
์ฆ, FSDP๋ ๋ชจ๋ธ ๋ณ๋ ฌ์ ๋ฉ๋ชจ๋ฆฌ ์ ์ฝ ์ฅ์ ๊ณผ ๋ฐ์ดํฐ ๋ณ๋ ฌ์ ๋น ๋ฅธ ๋ณ๋ ฌ์ฒ๋ฆฌ ์ฅ์ ์ ๊ฒฐํฉํ ๋ฐฉ์์ด๋ค.
2. FSDP ์ดํดํ๊ธฐ
GPU0 : shard 0
GPU1 : shard 1
GPU2 : shard 2
GPU3 : shard 3
๊ทธ๋์ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ฅผ ์์ ๊ฐ์ด ์ฌ๋ฌ GPU๊ฐ ๋๋ ์ ๊ฐ์ง๊ณ ์๋ค๋ ๊ฑด๋ฐ, ๊ทธ๋ผ ์ด๋ป๊ฒ ํ์ต์ ํ๋๊ฑธ๊น?
2.1 FSDP์ forward ๊ณผ์
- forward๋ฅผ ์์ํ๊ธฐ ์ ์,
- ํ์ํ layer (ํ๋ผ๋ฏธํฐ)๋ฅผ GPU๋ผ๋ฆฌ all-gather ํด์
→ ๊ฐ GPU๊ฐ forward ๊ณ์ฐ์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ์์๋ก ์ ๋ถ ๋ชจ์ - forward ๊ณ์ฐ์ด ๋๋๋ฉด,
- ๋ค์ ๊ทธ ํ๋ผ๋ฏธํฐ๋ค์ shard ๋จ์๋ก ๋ถ์ฐ ์ ์ฅ (reduce-scatter)ํด์
→ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ค์ ๋น์๋๋ค.
์ฆ, ๊ฐ๋ณ GPU๊ฐ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ํญ์ ๋ค๊ณ ์์ง ์์ง๋ง, forward ํ ๋๋ง ์ ์ GPU๋ผ๋ฆฌ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ชจ์์ ๊ณ์ฐ์ ํ๊ณ , ๋๋๋ฉด ๋ค์ shard๋ก ๋๋ ์ ์ ์ฅํ๋ ๋ฐฉ์์ด๋ผ๊ณ ๋ณด๋ฉด ๋๋ค.
์ข๋ ํฌ๋ฉํ๊ฒ ์ ๋ฆฌํ๋ฉด -- GPU๊ฐ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ํญ์ ๋ค๊ณ ์์ง ์์๋ forward ๊ณ์ฐ์ ํ ์ ์๋ ์ด์ ๋, ํ์ํ ๋๋ง ํ๋ผ๋ฏธํฐ๋ฅผ ๋คํธ์ํฌ๋ก ๋ชจ์์(all-gather) ๊ณ์ฐํ๊ธฐ ๋๋ฌธ์ด๋ค.
2.2 FSDP forward / backward ํต์ ํ๋ฆ
[ GPU0 ] [ GPU1 ] [ GPU2 ] [ GPU3 ]
| | | |
| | | |
|------ all-gather ํ๋ผ๋ฏธํฐ -----> |
| (๋ชจ๋ธ shard๋ฅผ ์๋ก ์ฃผ๊ณ ๋ฐ์) |
| => ๊ฐ GPU๊ฐ forward์ ํ์ํ |
| ์ ์ฒด ํ๋ผ๋ฏธํฐ๋ฅผ ์์๋ก ๋ณด์ |
| | | |
| forward & backward ๊ณ์ฐ |
| | | |
|<----- reduce-scatter ----------|
| (๊ณ์ฐ ๋๋ ํ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ค์ |
| shard ๋จ์๋ก ๋๋ ์ ์ฅ) |
| | | |
2.3 ๊ฒฐ๊ตญ ์๊ฐ์ ์ด๋๋ผ๋ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ๊ฐ๋ณ GPU์ ์ฌ๋ฆฌ๋ ๊ฒ ์๋๊ฐ??
๊ณต๋ถํ๋ค๋ณด๋ ์๋ฌธ์ด ๋ ๋ค. ๊ทธ๋ผ FSDP๋ฅผ ์ฐ๋ฉด ๊ฒฐ๊ตญ forward/backward ํ ๋ ๊ฐ๋ณ GPU๊ฐ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ๋ค๊ณ ์๋ ์ ์ธ๋ฐ, ๊ทธ๋ผ ์ ๊ตณ์ด ์ค๋๋ฅผ ๋๋๋๊ฐ? ๋ชจ๋ธ ์ ์ฒด๋ฅผ ๋ค๊ณ ์์ ์ ์๋๋ฐ?? GPU ๊ฐ ํต์ ๋ง ๋ง์์ง๋๋ฐ ๋ญ๊ฐ ํจ์จ์ ์ด๋ผ๋ ๊ฑธ๊น??
โ ์ฐ์ ๋ฉ๋ชจ๋ฆฌ์ ์ฌ๋ผ๊ฐ๋ ๊ฐ๋ค์ ์์๋ณด์
์์ | ์์ฑ ์๊ธฐ | ์ ํ์ํ๊ฐ | GPU ๋ฉ๋ชจ๋ฆฌ ์ํฅ | ํน์ง |
weights | ๋ชจ๋ธ ์ด๊ธฐํ | forward ๊ณ์ฐ ๋ฐ ์ ๋ฐ์ดํธ | ์๋์ ์ผ๋ก ์์ | - ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ(W, b) - ํญ์ GPU์ ์์ฃผ |
gradients | backward | weights ์ ๋ฐ์ดํธ | weights์ ๋น์ทํจ | - ∂Loss/∂W ๊ณ์ฐ ๊ฒฐ๊ณผ - backward ํ optimizer๊ฐ ์ฌ์ฉ |
activations | forward ๊ณ์ฐ ์ค๊ฐ | backward ์ gradient chain rule ๊ณ์ฐ์ ํ์ | ๊ฐ์ฅ ํผ | - ๊ฐ layer์ ์ถ๋ ฅ๊ฐ - batch size, sequence length์ ๋ฐ๋ผ ๊ธ๊ฒฉํ ์ฆ๊ฐ |
์ฆ, ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๊ฐ weights์ด๊ณ forward๋ฅผ ํ๋ฉด์ ์์ฑ๋๋ ๊ฐ๋ค์ด activations์ด๊ณ , ์ด activations๋ฅผ ํ์ฉํด์ backward ํ ๋ gradients๋ฅผ ๊ตฌํ๋ ๊ฑฐ๋ผ ๋ณด๋ฉด ๋๋ค.
โ FSDP ์ค์ ์ฐ์ฐ ๊ณผ์
layer1 : ํ์ํ shard๋ง all-gather → forward → reduce-scatter
layer2 : ํ์ํ shard๋ง all-gather → forward → reduce-scatter
...
- ๊ฒฐ๋ก ๋ง ์๊ธฐํ๋ฉด FSDP๋ forward/backward ๋์ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ํ๊บผ๋ฒ์ all-gather ํด์ GPU์ ์ฌ๋ฆฌ๋ ๋ฐฉ์์ด ์๋๋ค.
- FSDP๋ ๋ชจ๋ธ์ ์ฌ๋ฌ ๋ ์ด์ด ๋๋ ํ๋ผ๋ฏธํฐ bucket ๋จ์๋ก ์ชผ๊ฐ์ ๊ด๋ฆฌํ๋ค.
- ๊ทธ๋์ forward์์ ๋ ์ด์ด ํ๋๋ฅผ ๊ณ์ฐํ ๋
- ๊ทธ ๋ ์ด์ด์ ํ์ํ ํ๋ผ๋ฏธํฐ shard๋ง GPU๋ผ๋ฆฌ all-gather ํด์ ์ ์ฒด ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ๊ตฌ์ฑ
- forward ์ฐ์ฐ์ ์ํ
- ์ฐ์ฐ์ด ๋๋๋ฉด ๋ค์ reduce-scatter ํด์ shard๋ง ๋จ๊ธฐ๊ณ ๋๋จธ์ง ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋น์ด๋ค.
- backward๋ ๊ฐ์ ๋ฐฉ์์ผ๋ก ์ญ์์ผ๋ก ์ํํ๋ค.
๊ทธ๋์ forward/backward ๋์์๋ GPU๊ฐ ๋ชจ๋ธ ์ ์ฒด๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๋์์ ์ฌ๋ฆฌ๊ณ ์๋ ์๊ฐ์ ์๋ค. ๋ชจ๋ธ ํ๋๋ฅผ ํต์งธ๋ก GPU์ ์ฌ๋ฆด ์ ์๋ ๊ฒฝ์ฐ์๋, FSDP๋ layer-by-layer ๋๋ bucket-by-bucket ์ผ๋ก ๊ณ์ฐ์ ์ด์ด๋๊ฐ๋ฉด์ ํ์ต์ด ๊ฐ๋ฅํด์ง๋ค.
FSDP๊ฐ ๋ ๊ฑฐ์ bucket ๋ฐฉ์์ผ๋ก ๋ฌถ์ ํ๋ผ๋ฏธํฐ๋ฅผ chunk ๋จ์๋ก all-gather ํ๋๋ฐ, ์ด chunk size๋ฅผ ํฌ๊ฒ ์ก์ผ๋ฉด ์ฌ์ค์ forward ์ด๋ฐ์ ๋ง์ ํ๋ผ๋ฏธํฐ๋ฅผ ํ๊บผ๋ฒ์ GPU๋ก ๋ชจ์์ ์ฐ์ฐํ๊ธฐ ๋๋ฌธ์ peak๊ฐ ์ปค์ ธ์ DDP์ ์ ์ฌํ ๋ฉ๋ชจ๋ฆฌ ํจํด์ ๋ณด์ด๋ ๊ฒฝ์ฐ๊ฐ ์๋ค. ๊ทธ๋์ ์ด๊ฑธ ๋จ์ํํด “forward/backward ๋ GPU๊ฐ ์ ์ฒด๋ฅผ ๋ค๊ณ ์๋ค”๊ณ ๋ค ํํ ๋งํจ.
ํ์ง๋ง PyTorch ์ต์ FSDP (v1.12+ ์ดํ) ๋ bucket size๋ฅผ ์๊ฒ ํ๊ณ , ๋ ์ด์ด ๋จ์๋ก dynamicํ๊ฒ all-gather๋ฅผ ์ํํ๊ธฐ ๋๋ฌธ ์ค์ ๋ก๋ layer๋ณ๋ก ํ์ํ shard๋ง GPU์ ์ฌ๋ ค ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ๋ถ์ฐํ๋ค.
์์ | DDP | FSDP |
forward/backward ์ค | GPU๊ฐ ์ ์ฒด ํ๋ผ๋ฏธํฐ + activations + gradients ์ ์ง | ํ์ layer shard๋ง all-gather → ์ฐ์ฐ → reduce-scatter, ์ ์ฒด๋ฅผ ๋์์ ์ฌ๋ฆฌ์ง ์์ |
forward/backward ๋๋ ํ (steady) | GPU๊ฐ ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ ์ง | GPU๋ ์์ ์ shard๋ง ์ ์ง, optimizer step๋ shard ๊ธฐ๋ฐ |
โ FSDP forward ํ๋ฆ
[Layer1]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3
↓ all-gather (Layer1)
GPU0: weight1 ์ ์ฒด
GPU1: weight1 ์ ์ฒด
GPU2: weight1 ์ ์ฒด
GPU3: weight1 ์ ์ฒด
↓ forward ๊ณ์ฐ (Layer1)
↓ reduce-scatter
GPU0: shard0 (Layer1)
GPU1: shard1 (Layer1)
GPU2: shard2 (Layer1)
GPU3: shard3 (Layer1)
[Layer2]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3
↓ all-gather (Layer2)
GPU0: weight2 ์ ์ฒด
GPU1: weight2 ์ ์ฒด
GPU2: weight2 ์ ์ฒด
GPU3: weight2 ์ ์ฒด
↓ forward ๊ณ์ฐ (Layer2)
↓ reduce-scatter
GPU0: shard0 (Layer2)
GPU1: shard1 (Layer2)
GPU2: shard2 (Layer2)
GPU3: shard3 (Layer2)
[Layer3]
GPU0: shard0
GPU1: shard1
GPU2: shard2
GPU3: shard3
↓ all-gather (Layer3)
GPU0: weight3 ์ ์ฒด
GPU1: weight3 ์ ์ฒด
GPU2: weight3 ์ ์ฒด
GPU3: weight3 ์ ์ฒด
↓ forward ๊ณ์ฐ (Layer3)
↓ reduce-scatter
GPU0: shard0 (Layer3)
GPU1: shard1 (Layer3)
GPU2: shard2 (Layer3)
GPU3: shard3 (Layer3)
...
3. FSDP ์ฌ์ฉ์ ์ํ ์ฌ์ ์ค๋น
3.1 ํ๊ฒฝ ์ค์
- PyTorch 1.12 ์ด์ ํ์
- CUDA, NCCL ์ค์น ํ์ (GPU ๊ฐ ๋น ๋ฅธ ํต์ ์ง์)
- ๋ชจ๋ GPU ์๋ฒ์ ๋์ผํ ์ฝ๋, ๋ฐ์ดํฐ, ๋ผ์ด๋ธ๋ฌ๋ฆฌ ์ค์น ํ์
3.2 ๋ฉํฐ GPU ํ๊ฒฝ ์ค๋น
- ๋จ์ผ ๋จธ์ ๋๋ ๋ฉํฐ ๋จธ์ (๋ฉํฐ ๋ ธ๋) ํ๊ฒฝ ์ค๋น
- ๋ฉํฐ ๋จธ์ ์ด๋ผ๋ฉด MASTER_ADDR, MASTER_PORT ๋ฑ ์ค์ ํ์
4. PyTorch FSDP ์ฌ์ฉ ๋ฐฉ๋ฒ
4.1 ๊ธฐ๋ณธ ์ฝ๋ ๊ตฌ์กฐ
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader, DistributedSampler
# ๋ถ์ฐ ํ๊ฒฝ ์ด๊ธฐํ
dist.init_process_group(backend='nccl')
# ๋ชจ๋ธ ์์ฑ ๋ฐ FSDP ์ ์ฉ
model = MyModel().cuda()
fsdp_model = FSDP(model)
# ๋ฐ์ดํฐ ๋ก๋ ์ค์
dataset = MyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-4)
# ํ์ต ๋ฃจํ
for epoch in range(10):
sampler.set_epoch(epoch)
for inputs, targets in dataloader:
inputs, targets = inputs.cuda(), targets.cuda()
outputs = fsdp_model(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ๋ถ์ฐ ํ๊ฒฝ ์ ๋ฆฌ
dist.destroy_process_group()
4.2 ๋ฉํฐ ๋จธ์ ํ๊ฒฝ์์ FSDP ์คํ ๋ฐฉ๋ฒ
๋ฉํฐ ๋จธ์ ํ๊ฒฝ์ด๋ผ๋ฉด torchrun ๋๋ Slurm ๊ฐ์ job scheduler๋ฅผ ํตํด ์คํํ ์ ์๋ค.
torchrun ์์
torchrun --nnodes=2 --nproc_per_node=4 train.py
Slurm ์คํฌ๋ฆฝํธ ์์
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
srun torchrun --nnodes=$SLURM_JOB_NUM_NODES \
--nproc_per_node=$SLURM_NTASKS_PER_NODE \
train.py
FSDP ์ฌ์ฉ ์ ์ฃผ์์ฌํญ
- FSDP๋ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํฌ๊ฒ ์ ์ฝํ์ง๋ง, GPU ๊ฐ ํต์ (all-gather, reduce-scatter)์ด ๋์ด๋๋ค.
- checkpoint ์ ์ฅ ์ FSDP๊ฐ ๋๋ shard๋ค์ ํตํฉํด์ ์ ์ฅํด์ผ ํ๋ฏ๋ก, checkpoint ๊ด๋ฆฌ ๋ฐฉ๋ฒ์ ์ฃผ์๊ฐ ํ์ํ๋ค.
FSDP๋ ์ด๋๊ท๋ชจ ๋ชจ๋ธ์ ํ์ตํ ๋ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํจ์จ์ ์ผ๋ก ํ์ฉํ๊ธฐ ์ํ PyTorch์ ๊ฐ๋ ฅํ ๊ธฐ๋ฅ์ด๋ค. ํนํ ๋ชจ๋ธ ํฌ๊ธฐ๊ฐ ์ปค์ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํ ์ํฉ์์ ๋งค์ฐ ํจ๊ณผ์ ์ด๋ฉฐ, ๋ฐ์ดํฐ ๋ณ๋ ฌ์ ์ฅ์ ๊ณผ ๋ชจ๋ธ ๋ณ๋ ฌ์ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ ์ฅ์ ์ ๊ฒฐํฉํ ์ต์ ๋ถ์ฐ ํ์ต ๋ฐฉ๋ฒ์ด๋ค.