๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ์ ์ ์ปค์ง๊ณ ๋ฐ์ดํฐ๋ ๋ฐฉ๋ํด์ง๋ฉด์, ๋จ์ผ GPU๋ ์๋ฒ๋ง์ผ๋ก๋ ํ์ต ์๋๊ฐ ๋๋ฌด ๋๋ฆฌ๊ฑฐ๋ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํด ํ์ต์ด ๋ถ๊ฐ๋ฅํด์ง๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ฌ๋ฌ GPU๋ฅผ ๋์์ ํ์ฉํด ๋ชจ๋ธ์ ํ์ต์ํค๋ ๊ฒ์ด ๋ฐ๋ก ๋ถ์ฐ ํ์ต์ด๋ค.
1. ๋ถ์ฐ ํ์ต ์ข ๋ฅ
1.1 ๋ฐ์ดํฐ ๋ณ๋ ฌํ(Data Parallelism)
[์ ์ฒด ๋ฐ์ดํฐ] → [๋ถํ ๋ ๋ฏธ๋๋ฐฐ์น1] → GPU0 (๋ชจ๋ธ ๋ณต์ )
→ [๋ถํ ๋ ๋ฏธ๋๋ฐฐ์น2] → GPU1 (๋ชจ๋ธ ๋ณต์ )
→ [๋ถํ ๋ ๋ฏธ๋๋ฐฐ์น3] → GPU2 (๋ชจ๋ธ ๋ณต์ )
[๊ฐ GPU] → forward & backward → all-reduce → ๋๊ธฐํ → ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ
๊ฐ์ฅ ๋ณดํธ์ ์ผ๋ก ์ฌ์ฉ๋๋ ๋ฐฉ์์ด๋ค. ๋์ผํ ๋ชจ๋ธ์ ์ฌ๋ฌ GPU์ ๋ณต์ ํ๊ณ , ๋ฏธ๋๋ฐฐ์น ๋ฐ์ดํฐ๋ฅผ GPU๋ณ๋ก ๋๋์ด ์ฒ๋ฆฌํ๋ค. ๊ฐ GPU์์ forward์ backward๋ฅผ ๊ณ์ฐํ ๋ค, all-reduce ์ฐ์ฐ์ผ๋ก ๊ทธ๋๋์ธํธ๋ฅผ ํ๊ท ๋ด๊ณ ๋๊ธฐํํด ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
PyTorch์๋ DataParallel(DP)๊ณผ DistributedDataParallel(DDP) ๋ ๊ฐ์ง๊ฐ ์๋ค. DP๋ ๋จ์ผ ๋จธ์ ์์ ์ฌ๋ฌ GPU์ ๋ชจ๋ธ์ ๋ณต์ ํด ๋ฐ์ดํฐ๋ง ๋๋ ๋ฃ๋ ๋ฐฉ์์ด๋ฉฐ ๊ตฌํ์ด ๊ฐ๋จํ์ง๋ง, Python GIL(Global Interpreter Lock)๊ณผ single-process ๊ตฌ์กฐ๋ก ์ธํด ํต์ ๋ณ๋ชฉ์ด ์ฌํ๋ค. ๊ทธ๋์ ์ค์ ๋ก๋ ๊ฑฐ์ ํญ์ DDP๋ฅผ ์ฌ์ฉํ๋ค. DDP๋ ๊ฐ GPU๋ง๋ค ๋ณ๋์ ํ๋ก์ธ์ค๋ฅผ ์์ฑํด ํต์ ๋ณ๋ชฉ์ ์ค์ด๊ณ ํจ์จ์ ์ผ๋ก all-reduce๋ฅผ ์ํํด ๊ทธ๋๋์ธํธ๋ฅผ ๋๊ธฐํํ๋ค.
- ์์: PyTorch DDP, TensorFlow MirroredStrategy
- ์ฅ์ : ๊ตฌํ์ด ๋จ์ํ๊ณ GPU๋ฅผ ๋๋ฆฌ๊ธฐ ์ฝ๋ค.
- ๋จ์ : ๋ชจ๋ธ ์ ์ฒด๊ฐ ๊ฐ GPU์ ๋ณต์ ๋๋ฏ๋ก, GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํ๋ฉด ๋ถ๊ฐ๋ฅํ๋ค.
1.2 ๋ชจ๋ธ ๋ณ๋ ฌํ(Model Parallelism)
[์
๋ ฅ ๋ฐ์ดํฐ] → GPU0 (Model Layer1-5) → GPU1 (Model Layer6-10) → GPU2 (Model Layer11-15)
forward → GPU ๊ฐ ์ฐ์ ์ ๋ฌ → backward → GPU ๊ฐ ์ฐ์ ์ ๋ฌ
๋ชจ๋ธ ์์ฒด๋ฅผ ์ฌ๋ฌ GPU์ ๋๋ ์ ์ฅํ๊ณ ์์ฐจ์ ์ผ๋ก forward ์ฐ์ฐ์ GPU๋ฅผ ๊ฑฐ์ณ ์งํํ๋ค. GPT-3, T5 ๊ฐ์ ํ๋์ GPU์ ์ฌ๋ผ๊ฐ์ง ์์ ์ ๋๋ก ๊ฑฐ๋ํ ๋ชจ๋ธ์์ ์ฃผ๋ก ์ฌ์ฉํ๋ค. PyTorch์์๋ ํน์ ๋ ์ด์ด๋ฅผ ์๋์ผ๋ก cuda:0, cuda:1์ ์ฌ๋ฆฌ๊ฑฐ๋, Megatron-LM๊ณผ ๊ฐ์ ํ๋ ์์ํฌ๊ฐ Tensor Parallelism์ ํตํด ์๋์ผ๋ก ๋ ์ด์ด๋ฅผ ๋๋ ์ค๋ค.
- ์์: Megatron-LM (Tensor Parallelism), manual PyTorch split
- ์ฅ์ : ์ด๋๊ท๋ชจ ๋ชจ๋ธ์ ์ฌ๋ฌ GPU์ ๋๋ ์ ์ฒ๋ฆฌ ๊ฐ๋ฅํ๋ค.
- ๋จ์ : GPU ๊ฐ ํต์ ๋์ด ๋ง์ latency์ bandwidth ๋ณ๋ชฉ์ด ๋ฐ์ํ๊ธฐ ์ฝ๋ค. Layer๊ฐ ์ข ์์ผ๋ก ์ธํด ๋ณ๋ ฌ์ฑ์ด ์ ํ๋๋ค.
1.3 ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ(Pipeline Parallelism)
[๋ฏธ๋๋ฐฐ์น1] → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)
[๋ฏธ๋๋ฐฐ์น2] → → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)
[๋ฏธ๋๋ฐฐ์น3] → → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)
→ ํ์ดํ๋ผ์ธ ์ฑ์์ bubble ์ต์ํ
๋ชจ๋ธ์ ์ฌ๋ฌ stage๋ก ๋๋ GPU์ ๋ฐฐ์นํ๊ณ , ๋ฐ์ดํฐ๋ฅผ ์ฐ์์ ์ผ๋ก ํ๋ ค๋ณด๋ด ์ฒ๋ฆฌํ๋ ๋ฐฉ์์ด๋ค. ๋ฏธ๋๋ฐฐ์น๋ฅผ ๋ ์๊ฒ ์ชผ๊ฐ pipeline์ ์ฑ์ idle time(bubble)์ ์ค์ธ๋ค. ๊ตฌ์กฐ ์์ฒด๋ ๋ชจ๋ธ ๋ณ๋ ฌํ์ฒ๋ผ layer๋ฅผ ์ฌ๋ฌ GPU์ ๋๋ ๋์ง๋ง, GPU0์์ ๋ฏธ๋๋ฐฐ์น1์ ์ฐ์ฐํ๊ณ GPU1๋ก ๋๊ธด ๋ค GPU0์ ๋ฐ๋ก ๋ฏธ๋๋ฐฐ์น2๋ฅผ ์ฐ์ฐํ๊ธฐ ์์ํด GPU๋ค์ด ์ฌ์ง ์๋๋ก ๋ง๋๋ ๋ฐฉ์์ด๋ค. Deepspeed Pipeline์ด๋ PyTorch torch.distributed.pipeline.sync.Pipe๊ฐ ์ด๋ฅผ ์ง์ํ๋ค.
- ์์: Deepspeed Pipeline, PyTorch Pipe
- ์ฅ์ : layer๋ฅผ stage๋ณ๋ก ๋ถ๋ฆฌํด ๋ฉ๋ชจ๋ฆฌ ๋ถ๋ด์ ๋ ์ธ๋ฐํ ๋ถ์ฐํ ์ ์๋ค.
- ๋จ์ : ์์ฐจ ์ฒ๋ฆฌ ๊ตฌ์กฐ๋ผ์ latency๊ฐ ์ฆ๊ฐํ๋ฉฐ, bubble์ด ๋ฐ์ํ์ง ์๋๋ก carefulํ๊ฒ microbatch๋ฅผ ์กฐ์ ํด์ผ ํ๋ค.
์ด๋ค ๋ถ์ฐ ํ์ต์ ์ธ์ ์ฐ๋?
- ์์~์ค๊ฐ ํฌ๊ธฐ ๋ชจ๋ธ (์์ฒ๋ง~์์ต ํ๋ผ๋ฏธํฐ): ๋จ์ํ DDP๋ง์ผ๋ก ์ถฉ๋ถ. GPU๋ฅผ ์ฌ๋ฌ ๊ฐ ์ธ์๋ก ํ์ต ์๋๊ฐ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํ๋ค.
- ํ๋์ GPU ๋ฉ๋ชจ๋ฆฌ์ ์ฌ๋ผ๊ฐ์ง ์๋ ๋ชจ๋ธ (์์ญ์ต ํ๋ผ๋ฏธํฐ): ๋ชจ๋ธ ๋ณ๋ ฌ ๋๋ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ์ ์ ์ฉํด GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ถ์ฐ.
- ์์ญ์ต~์๋ฐฑ์ต ํ๋ผ๋ฏธํฐ ์ด์ ์ด๊ฑฐ๋ ๋ชจ๋ธ: DDP + ๋ชจ๋ธ ๋ณ๋ ฌ + ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ์ ํจ๊ป ์กฐํฉํ๋ค. Megatron-LM, Deepspeed ZeRO-Infinity ๊ฐ์ ํ๋ ์์ํฌ๊ฐ ์ด๋ฅผ ์๋ํํด์ค๋ค.
๋ฐ๋ผ์ ๋ชจ๋ธ ํฌ๊ธฐ, GPU ๋ฉ๋ชจ๋ฆฌ, ๋คํธ์ํฌ ๋์ญํญ์ ๊ณ ๋ คํด ์ ์ ํ ๋ถ์ฐ ๋ฐฉ์ ๋๋ ์กฐํฉ์ ์ ํํ๋ ๊ฒ์ด ์ค์ํ๋ค.
2. PyTorch DistributedDataParallel(DDP) ์ฌ์ฉํ๊ธฐ
PyTorch DistributedDataParallel(DDP)์ ๋ฐ์ดํฐ ๋ณ๋ ฌ(Data Parallelism)์ ์ํ ๊ฐ์ฅ ํ์ค์ ์ธ ๋ฐฉ๋ฒ์ด๋ค. ์ฌ๋ฌ GPU์ ๋์ผํ ๋ชจ๋ธ์ ์ฌ๋ฆฌ๊ณ , ๊ฐ GPU๊ฐ ์๋ก ๋ค๋ฅธ ๋ฐ์ดํฐ ๋ฏธ๋๋ฐฐ์น๋ฅผ ํ์ตํ ๋ค, all-reduce๋ฅผ ํตํด ๊ทธ๋๋์ธํธ๋ฅผ ๋๊ธฐํํด ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
๋จ์ผ ๋จธ์ (์ฌ๋ฌ GPU)์์๋ ์ธ ์ ์๊ณ , ๋ฉํฐ ๋ ธ๋(์ฌ๋ฌ ์๋ฒ)์์๋ ๋์ผํ ์ฝ๋ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ ์ ์๋ค.
all-reduce๋ ๊ฐ GPU์์ ๊ณ์ฐํ ๊ทธ๋๋์ธํธ๋ฅผ ์๋ก ์ฃผ๊ณ ๋ฐ์ ํ๊ท ์ ๊ตฌํ๊ณ ๋๊ธฐํํ๋ ํต์ ์ฐ์ฐ์ด๋ค. ์ด๋ฅผ ํตํด ๋ชจ๋ GPU๊ฐ ๋์ผํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ์งํ๋ฉด์ ํ์ต์ ์งํํ ์ ์๋ค. ์ฃผ๋ก NCCL๋ก ๊ณ ์ ์ํ๋๋ฉฐ, DDP์์ ์๋์ผ๋ก ์ํ๋๋ค.
2.1 ์ฌ์ ์ค๋น
- ๋ชจ๋ ๋จธ์ (๋ ธ๋)์ ๋์ผํ ์ฝ๋, ๋ฐ์ดํฐ, ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ด ์์ด์ผ ํ๋ค.
- CUDA์ NCCL์ด ์ค์น๋์ด ์์ด์ผ ํ๋ค. (PyTorch์ backend๋ก ์ฃผ๋ก nccl์ ์ฌ์ฉ)
- ๋จ์ผ ๋จธ์ ์ด๋ฉด CUDA_VISIBLE_DEVICES๋ก GPU๋ฅผ ๊ด๋ฆฌํ๋ฉด ๋๊ณ , ๋ฉํฐ ๋ ธ๋๋ผ๋ฉด MASTER_ADDR, MASTER_PORT ํ๊ฒฝ๋ณ์๋ฅผ ์ค์ ํด์ผ ํ๋ค.
๋ฉํฐ ๋ ธ๋ ํ๊ฒฝ์์๋ ๋ณดํต MASTER_ADDR์ rank=0 ๋จธ์ ์ IP๋ก ์ค์ ํ๋ค.
export MASTER_ADDR=192.168.1.10
export MASTER_PORT=12355
2.2 DDP ๊ธฐ๋ณธ ์ฝ๋ ๊ตฌ์กฐ
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def train(rank, world_size):
setup(rank, world_size)
device = torch.device(f'cuda:{rank}')
model = MyModel().to(device)
ddp_model = DDP(model, device_ids=[rank])
dataset = MyDataset()
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(10):
sampler.set_epoch(epoch)
for batch in dataloader:
inputs, targets = batch
inputs, targets = inputs.to(device), targets.to(device)
outputs = ddp_model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
cleanup()
์ด ์คํฌ๋ฆฝํธ๋ GPU 0, GPU 1, GPU 2... ๊ฐ๊ฐ์ด ๋ ๋ฆฝ์ ์ธ Python ํ๋ก์ธ์ค๋ก ์คํ๋์ด, ๋์ผํ ๋ชจ๋ธ์ ๊ฐ GPU์์ ํ์ตํ๊ณ all-reduce๋ฅผ ํตํด ๊ทธ๋๋์ธํธ๋ฅผ ๋๊ธฐํํ๋ค.
2.3 ์ค์ ์คํํ๊ธฐ
โ ๋จ์ผ ๋จธ์
import torch.multiprocessing as mp
def main():
world_size = 4
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
if __name__ == '__main__':
main()
๋จ์ผ ๋จธ์ ์์ GPU๊ฐ 4๊ฐ๋ผ๋ฉด torch.multiprocessing.spawn์ ์ฌ์ฉํด ์์ ๊ฐ์ด ์คํํ ์ ์๋ค.
โ ๋ฉํฐ ๋จธ์
๋ฉํฐ ๋จธ์ ์์๋ ๋ชจ๋ ๋จธ์ ์์ ๋์ผํ๊ฒ export์ torchrun์ ์คํํด์ผ ํ๋ค. ์ฆ, master(๋ณดํต rank=0) ๋จธ์ ์์๋ง ํ๊ฒฝ ๋ณ์๋ฅผ ์ค์ ํ๊ฑฐ๋ torchrun์ ๋๋ฆฌ๋ ๊ฒ์ด ์๋๋ผ, ๋ชจ๋ ๋จธ์ ์์ ํ๊ฒฝ ๋ณ์๋ฅผ ๋์ผํ๊ฒ ์ค์ ํ๊ณ torchrun์ ๊ฐ๊ฐ ์คํํด์ผ ํ๋ค.
์๋ฅผ ๋ค์ด ๋จธ์ ์ด 3๋ ์๊ณ , ๊ฐ๊ฐ GPU๊ฐ 4๊ฐ์ฉ ์๋ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ด ํ๋ค.
export MASTER_ADDR=192.168.1.10
export MASTER_PORT=12355
# ๋จธ์ 1 (rank=0)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=0 train.py
# ๋จธ์ 2 (rank=1)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=1 train.py
# ๋จธ์ 3 (rank=2)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=2 train.py
- ๋ชจ๋ ๋จธ์ ์์ ๋์ผํ๊ฒ export ํ๋ค. MASTER_ADDR์ master ์ญํ ์ ํ rank=0 ๋จธ์ ์ IP๋ฅผ ์ ๋๋ค.
- “ํต์ ์ ์ํ ๊ธฐ์ค์ด ๋๋ ๋ง์คํฐ IP(=rank=0 ๋จธ์ )”๋ฅผ ์๋ ค์ฃผ๋ ๊ฒ
- ๊ทธ๋ฆฌ๊ณ ๊ฐ ๋จธ์ ์์ ์์ ์ node_rank์ ๋ง๊ฒ torchrun์ ์คํํ๋ค.
- ์ฆ, ๊ฐ ๋จธ์ ์ด ๋ชจ๋ ๋์์ ์ด ๋ช ๋ น์ ์คํํด์ผ DDP๊ฐ ์ ์์ ์ผ๋ก ํต์ ์ ์์ํ๋ค.
- ๋ง์คํฐ ๋จธ์ ๋ง ์คํํ๋ ๊ฒ์ด ์๋๋ฉฐ, ๋ชจ๋ ๋จธ์ ์ด ์์ ์ node_rank๋ฅผ ์ง์ ํด ๋์ผํ train.py๋ฅผ ์คํํด์ผ ํ๋ค.
2.4 ์ฃผ์ ๊ฐ๋ ๊ณผ ์ฃผ์์ฌํญ
- rank ๋ ์ ์ฒด ํ๋ก์ธ์ค์์ ๊ณ ์ ๋ฒํธ๋ฅผ ๋งํ๋ค. ์๋ฅผ ๋ค์ด 8 GPU๋ฉด rank=0~7๊น์ง ์๋ค.
- world_size ๋ ์ ์ฒด GPU(=์ ์ฒด ํ๋ก์ธ์ค) ์๋ฅผ ์๋ฏธํ๋ค.
- DistributedSampler ๋ ๊ฐ ํ๋ก์ธ์ค๊ฐ ๊ฐ์ ๋ฐ์ดํฐ์ ์ ์๋ก ๋ค๋ฅธ ์์/๋ฒ์๋ก ์ฝ๊ฒ ํ๋ค. epoch๋ง๋ค set_epoch(epoch)๋ฅผ ํธ์ถํด์ผ ๋ฐ์ดํฐ ์ ํ๋ง์ด ์ ๋์ํ๋ค.
- DDP๋ DataParallel(DP)๊ณผ ๋ฌ๋ฆฌ Python GIL ๋ณ๋ชฉ ์์ด ๊ฐ ํ๋ก์ธ์ค๊ฐ GPU ํ๋์ฉ์ ์ ๋ดํ๋ฏ๋ก ํจ์ฌ ํจ์จ์ ์ด๋ค.
PyTorch DDP๋ ๋จ์ผ ๋จธ์ , ๋ฉํฐ ๋จธ์ ๋ชจ๋์์ ์ฌ์ฉํ ์ ์๋ ๊ฐ๋ ฅํ๊ณ ํ์ค์ ์ธ ๋ฐ์ดํฐ ๋ณ๋ ฌ ํ๋ ์์ํฌ์ด๋ค. setup → DDP๋ก ๋ชจ๋ธ ๊ฐ์ธ๊ธฐ → DistributedSampler → backward์์ all-reduce ์์ผ๋ก ๋์ํ๋ค. ์ด๋ฅผ ํตํด ๋์ผํ ๋ชจ๋ธ์ GPU๋ง๋ค ๋ ๋ฆฝ์ ์ผ๋ก ํ์ตํ๊ณ , ๊ทธ๋๋์ธํธ๋ฅผ ์๋์ผ๋ก ํต์ ·๋๊ธฐํํด ํจ์จ์ ์ผ๋ก ๋๊ท๋ชจ ๋ชจ๋ธ์ ํ์ตํ ์ ์๋ค.
3. Slurm์์ ๋ฉํฐ ๋จธ์ PyTorch DDP ์คํํ๊ธฐ
๋ฉํฐ ๋ ธ๋ ํ๊ฒฝ์์ Slurm์ ์ฌ์ฉํ๋ฉด srun์ผ๋ก ์ฌ๋ฌ ๋จธ์ ์์ ๋์์ torchrun์ ์คํํ ์ ์์ด ํธ๋ฆฌํ๋ค. Slurm์ด ์๋์ผ๋ก MASTER_ADDR, RANK, WORLD_SIZE ๊ฐ์ ํ๊ฒฝ ๋ณ์๋ฅผ ์ธํ ํด ์ฃผ๊ธฐ ๋๋ฌธ์ ๋ฐ๋ก export ํ ํ์๋ ์๋ค.
ํ์ง๋ง, ์ฌ๋ผ์ GPU ์๋ฒ๋ฅผ ํด๋ฌ์คํฐ๋ก ๊ตฌ์ฑํด ๋์ ํ๊ฒฝ์์๋ง ์ธ ์ ์๋ ๋๊ตฌ๋ผ ๊ทธ๋ฅ EC2 ์ฌ๋ฌ ๋๋ฅผ ๋์ด๋ค๊ณ ๊ณง์ฅ ์ฌ์ฉํ ์ ์๋ ๊ฒ ์๋๊ณ , ๋ฐ๋ก Slurm ์ํ๊ณ(=slurmd, slurmctld, config)๋ฅผ ์ค์น + ์ค์ ํด์ผ ํ๋ค.
3.1 Slurm ์คํฌ๋ฆฝํธ ์์ (multi-node DDP)
์๋๋ Slurm job ์คํฌ๋ฆฝํธ(train_job.sh) ์์ ์ด๋ค.
#!/bin/bash
#SBATCH --job-name=ddp_train
#SBATCH --nodes=3 # ์ฌ์ฉํ ๋
ธ๋ ์
#SBATCH --ntasks-per-node=4 # ๋
ธ๋๋น GPU ๊ฐ์
#SBATCH --gres=gpu:4 # ๋
ธ๋๋น GPU ํ ๋น
#SBATCH --cpus-per-task=4
#SBATCH --time=24:00:00
#SBATCH --partition=gpu
module load cuda/12.1
module load python/3.11
srun torchrun \
--nnodes=$SLURM_JOB_NUM_NODES \
--nproc_per_node=$SLURM_NTASKS_PER_NODE \
--node_rank=$SLURM_NODEID \
train.py
#SBATCH --nodes=3, #SBATCH --ntasks-per-node=4 ์ด๋ผ๊ณ ์ค์ ํ์ผ๋ฏ๋ก
| ์ฌ๋ผ ๋ณ์ | ์ค์ ์์ ๊ฐ |
| $SLURM_JOB_NUM_NODES | 3 |
| $SLURM_NTASKS_PER_NODE | 4 |
| $SLURM_NODEID | ๊ฐ ๋ ธ๋์์ 0,1,2 |
์ฌ๋ผ ๋ณ์๋ ์์ ๊ฐ์ด ์ธํ ๋๊ณ , ๊ทธ๋ฌ๋ฉด srun์ด ๊ฐ ๋ ธ๋์์ ์ด๋ ๊ฒ ์คํํ๊ฒ ๋ฉ๋๋ค.
# ๋
ธ๋1์์
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=0 train.py
# ๋
ธ๋2์์
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=1 train.py
# ๋
ธ๋3์์
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=2 train.py
ํน์ ๋ ธ๋๋ฅผ ์ฌ์ฉํด์ job์ ์คํํ๊ณ ์ถ์ผ๋ฉด --nodelist=node05,node06,node07 ๋๋ --nodelist=node[05-07] ์ ๊ฐ์ด ์ธํ ํ๋ฉด ๋๋ค.
3.2 ๋จ๊ณ๋ณ ๋์ ๋ฐ ํน์ง
- srun์ด Slurm์์ ๊ฐ ๋ ธ๋์ ์์ ์ ์๋์ผ๋ก ๋ฐฐํฌํ๋ค.
- Slurm์ด MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, NODEID ๋ฑ์ ์๋์ผ๋ก ์ค์ ํด ๊ฐ ๋ ธ๋๊ฐ ๋์ผํ๊ฒ ๊ณต์ ํ๋ค.
- torchrun์ ์ด ํ๊ฒฝ ๋ณ์๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๊ฐ ๋ ธ๋๊ฐ ํต์ ํ ๋ง์คํฐ๋ฅผ ์๋์ผ๋ก ์ฐพ๊ณ init_process_group()์ ํตํด ์ฐ๊ฒฐ์ ๋งบ๋๋ค.
- ๋ชจ๋ ๋ ธ๋์์ ๋์์ forward, backward๋ฅผ ์ํํ๊ณ , all-reduce๋ฅผ ํตํด ๊ทธ๋๋์ธํธ๋ฅผ ๋๊ธฐํํ๋ค.
๋ํ,
- ํ์ค ์ถ๋ ฅ๊ณผ ์๋ฌ ๋ก๊ทธ๋ฅผ slurm-<jobid>.out์ ์๋ ์ ์ฅํ๋ค.
- squeue, scontrol show job <jobid> ๊ฐ์ ๋ช ๋ น์ด๋ก ์ํ(๋๊ธฐ, ์คํ ์ค, ์ข ๋ฃ, ์คํจ)๋ฅผ ์ฝ๊ฒ ํ์ธํ ์ ์๋ค.
- Slurm Array Job, Checkpointing, Preemption ๊ฐ์ ๊ธฐ๋ฅ์ ํตํด ์ฅ๊ธฐ ํ์ต์์ ์๋ฒ๊ฐ ๋ค์ด๋๋๋ผ๋ ๋ณต๊ตฌํ๊ฑฐ๋ ์ฌ์์ํ ์ ์๋ ๊ธฐ๋ฅ์ ์ง์ํ๋ค.
- sbatch๋ฅผ ์ฌ๋ฌ ๋ฒ ์ ์ถํ๊ฑฐ๋, Array Job(--array=0-9)์ ์ฌ์ฉํด ํ์ดํผํ๋ผ๋ฏธํฐ ํ๋ ์คํ์ ํ๊บผ๋ฒ์ 10๊ฐ, 100๊ฐ๋ ๋์์ ๋๋ฆด ์ ์๋ค.
3.3 Slurm์์ ๋ฉํฐ ๋ ธ๋ PyTorch DDP ํ์ตํ๊ธฐ
์ฆ ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฑํ ๋ค, Slurm์ ์ ์ถํ๋ฉด ๋๋ค.
sbatch train_job.sh
์ด๋ ๊ฒ ํ๋ฉด Slurm์ด ์๋์ผ๋ก 3๋ ๋จธ์ ์ ๊ฐ๊ฐ torchrun์ ์ฌ๋ฐ๋ฅด๊ฒ ์คํ์์ผ ๋ฉํฐ ๋ ธ๋ DDP ํ์ต์ ์์ํ๋ค.
์ฆ, ์๋๋ ๊ฐ ๋จธ์ ์ ์ง์ ๋ค์ด๊ฐ์ torchrun์ ์คํํด์ผ ํ๋๋ฐ, Slurm์ด ์ด๋ฅผ ๋์ ํด์ ๊ฐ ๋ ธ๋์ ๋ค์ด๊ฐ torchrun์ ์๋์ผ๋ก ์คํํด ์ฃผ๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์๋ ๋ฉ์ธ ์๋ฒ(=Slurm ์ปจํธ๋กค ๋ ธ๋)์์ sbatch train_job.sh ํ ๋ฒ๋ง ์คํํ๋ฉด ๋๋ค.
๐ก ์ ๋ฆฌ
- Slurm + srun์ ์ฌ์ฉํ๋ฉด ์ฌ๋ฌ ๋จธ์ ์์ rank, master ip๋ฅผ ์ง์ ์ง์ ํ์ง ์์๋ ๋๋ค.
- Slurm์ด --node_rank์ ํ์ํ ํ๊ฒฝ ๋ณ์๋ฅผ ์๋์ผ๋ก ์ก์์ฃผ๋ฏ๋ก ๋งค์ฐ ํธ๋ฆฌํ๋ค.
- ๋ฐ๋ผ์ ๋๊ท๋ชจ GPU ํด๋ฌ์คํฐ ํ๊ฒฝ์์๋ Slurm๊ณผ PyTorch DDP๋ฅผ ๊ฒฐํฉํด ๋ฉํฐ ๋ ธ๋ ํ์ต์ ์์ ์ ์ด๊ณ ์ฝ๊ฒ ๊ตฌํํ ์ ์๋ค.
4. PyTorch๋ก ๋ชจ๋ธ ๋ณ๋ ฌํ ๊ตฌํํ๊ธฐ
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ํ๋์ GPU์ ์ฌ๋ฆฌ๊ธฐ์๋ ๋๋ฌด ์ปค์ง๋ฉด, ๋ชจ๋ธ ๋ณ๋ ฌ(Model Parallelism) ์ด ํ์ํ๋ค. PyTorch๋ DataParallel, DistributedDataParallel์ฒ๋ผ ์๋์ผ๋ก ๋ชจ๋ธ์ ์ฌ๋ฌ GPU์ ๋ณต์ ํด ๋ฐ์ดํฐ๋ง ๋ถ์ฐํ๋ ๋ฐ์ดํฐ ๋ณ๋ ฌ ๋ฐฉ์๊ณผ ๋ฌ๋ฆฌ, ๋ชจ๋ธ ๋ณ๋ ฌ์ ์ฌ์ฉ์๊ฐ ์ง์ ๋ชจ๋ธ์ ๋ ์ด์ด๋ฅผ GPU์ ๋๋ ๋ฐฐ์นํด forward, backward๋ฅผ GPU ๊ฐ์ ์์ฐจ์ ์ผ๋ก ํ๋ ค๋ณด๋ด๋ ๋ฐฉ์์ด๋ค.
4.1 ๊ธฐ๋ณธ ์๋ ๋ชจ๋ธ ๋ณ๋ ฌํ (Manual Model Parallelism)
๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ๋ชจ๋ธ ๋ณ๋ ฌํ๋ ๋ชจ๋ธ์ ์ผ๋ถ ๋ ์ด์ด๋ฅผ cuda:0์, ๋๋จธ์ง๋ฅผ cuda:1์ ์ฌ๋ ค์ forward ๊ณ์ฐ ์ ๋ฐ์ดํฐ๋ฅผ GPU ๊ฐ์ ์ ์กํ๋๋ก ๋ง๋๋ ๊ฒ์ด๋ค.
โ ๋จ๊ณ๋ณ ๊ตฌํ
- ๋ชจ๋ธ์ ๊ฐ ํํธ๋ฅผ ์ํ๋ GPU์ ์ฌ๋ฆฐ๋ค.
- forward ํจ์์์ to('cuda:x')๋ฅผ ํตํด ์ถ๋ ฅ์ ๋ค์ GPU๋ก ๋ณด๋ธ๋ค.
- backward๋ PyTorch์ Autograd๊ฐ ์๋์ผ๋ก GPU ๊ฐ ํต์ ์ ํตํด gradient๋ฅผ ๊ณ์ฐํ๋ค.
โ ์์ ์ฝ๋
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layer1 = nn.Linear(1024, 2048).to('cuda:0')
self.relu = nn.ReLU()
self.layer2 = nn.Linear(2048, 1024).to('cuda:1')
def forward(self, x):
x = x.to('cuda:0')
x = self.layer1(x)
x = self.relu(x)
x = x.to('cuda:1')
x = self.layer2(x)
return x
# ์ฌ์ฉ ์์
model = MyModel()
input_data = torch.randn(64, 1024).to('cuda:0')
output = model(input_data)
์ด ๊ตฌ์กฐ์์๋ layer1์ GPU0์์, layer2๋ GPU1์์ ์ํ๋๋ค. ๋ฐ์ดํฐ๋ forward ์ GPU0 → GPU1, backward ์ GPU1 → GPU0 ์์๋ก ํต์ ํ๋ฉฐ ์๋์ผ๋ก gradient๊ฐ ๊ณ์ฐ๋๋ค.
4.2 Pipeline Parallelism ์ฌ์ฉํ๊ธฐ (torch.distributed.pipeline.sync.Pipe)
PyTorch๋ Pipe ๋ชจ๋์ ํตํด ๋ชจ๋ธ์ stage ๋จ์๋ก ๋๋์ด ์ฌ๋ฌ GPU์ ๋ฐฐ์นํ๊ณ , ์ ๋ ฅ์ micro-batch๋ก ์๊ฒ ๋๋ ํ์ดํ๋ผ์ธ์ ์ฑ์ throughput์ ๊ทน๋ํํ๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ค.
โ ๋จ๊ณ๋ณ ๊ตฌํ
- ๋ชจ๋ธ์ torch.nn.Sequential๋ก ์ ์ํ๋ค.
- Pipe๋ฅผ ์ฌ์ฉํด ๊ฐ stage๋ฅผ ์ง์ ํ GPU์ ์๋์ผ๋ก ์ฌ๋ฆฌ๋๋ก ํ๋ค.
- chunks ํ๋ผ๋ฏธํฐ๋ฅผ ์ฌ์ฉํด ์ ๋ ฅ์ ์๊ฒ ์ชผ๊ฐ pipeline bubble์ ์ค์ธ๋ค.
โ ์์ ์ฝ๋
import torch
import torch.nn as nn
from torch.distributed.pipeline.sync import Pipe
# ๋ชจ๋ธ์ stage๋ก ๋๋ Sequential ์ ์
model = nn.Sequential(
nn.Linear(1024, 2048),
nn.ReLU(),
nn.Linear(2048, 1024)
)
# Pipe๋ฅผ ํตํด GPU 2๊ฐ์ ๋๋์ด ๋ฐฐ์น, ์
๋ ฅ์ 8๊ฐ๋ก ์ชผ๊ฐ pipeline ์ฒ๋ฆฌ
model = Pipe(model, devices=['cuda:0', 'cuda:1'], chunks=8)
# forward ์คํ
input_data = torch.randn(64, 1024).to('cuda:0')
output = model(input_data)
์ด๋ ๊ฒ ํ๋ฉด PyTorch๊ฐ ์๋์ผ๋ก stage1์ cuda:0, stage2๋ cuda:1์ ์ฌ๋ฆฌ๊ณ , ์ ๋ ฅ์ 8๊ฐ๋ก ๋๋ ํ์ดํ๋ผ์ธ์ ๋๋ฆฌ๋ฉฐ GPU๋ค์ด ์ฌ์ง ์๋๋ก ํ์ต์ ์ํํ๋ค.
4.3 ๋ชจ๋ธ ๋ณ๋ ฌํ๋ฅผ ์ธ ๋ ์ฃผ์ํ ์
- ๋ชจ๋ธ ๋ณ๋ ฌ์ ๋ฐ์ดํฐ ๋ณ๋ ฌ(DDP)์ฒ๋ผ ๊ฐ GPU๊ฐ ๊ฐ์ ๋ชจ๋ธ์ ๋ณต์ ํ๋ ๋ฐฉ์์ด ์๋๋ฏ๋ก, GPU ๊ฐ ํต์ ์ด ์ฆ์ ๋คํธ์ํฌ ๋ณ๋ชฉ์ด ๋ฐ์ํ๊ธฐ ์ฝ๋ค.
- ํ๋ผ๋ฏธํฐ๊ฐ GPU๋ง๋ค ๋๋์ด ์์ผ๋ฏ๋ก, ๋ชจ๋ธ ์ ์ฅ(Checkpoint) ์ ๊ฐ GPU์ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ฐ๋ก ์ ์ฅํด์ผ ํ๋ค.
- Pipe๋ฅผ ์ธ ๊ฒฝ์ฐ chunks๋ฅผ ์ ์ ํ ์กฐ์ ํด bubble(๋น ์๊ฐ)์ ์ต์ํํ๋ ๊ฒ์ด ์ค์ํ๋ค.
PyTorch๋ ๋ชจ๋ธ ๋ณ๋ ฌ์ ์๋์ผ๋ก ๊ตฌํํ๊ฑฐ๋ Pipe๋ฅผ ์ด์ฉํด ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ์ ๊ตฌํํ ์ ์๋ค. ์ด๋ ํ๋์ GPU์ ๋ด๊ธฐ ์ด๋ ค์ด ์ด๋ํ ๋ชจ๋ธ์ ํ์ตํ ๋ ์ ์ฉํ๋ฉฐ, ํ์์ ๋ฐ๋ผ Megatron-LM, Deepspeed ๊ฐ์ ๋ ๋ฐ์ ๋ ํ๋ ์์ํฌ๋ก ๋์ด๊ฐ ์ ์๋ค.
'๐ ๏ธ Engineering > Distributed Training & Inference' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
| vLLM์ ํ์ฉํ Large-scale AI ๋ชจ๋ธ ๊ฐ์ํ | LLM Acceleration (0) | 2025.12.16 |
|---|---|
| DeepSpeed ์๋ฒฝ ์ดํดํ๊ธฐ! (1) | 2025.07.07 |
| PyTorch FSDP (Fully Sharded Data Parallel) ์๋ฒฝ ์ดํดํ๊ธฐ! (4) | 2025.07.06 |
| GPU ํด๋ฌ์คํฐ: SuperPOD์ Slurm์ ๊ฐ๋ ๊ณผ ํ์ฉ๋ฒ (1) | 2025.07.03 |