PyTorch Distributed Training Basics: Data Parallelism, Model Parallelism, and Pipeline Parallelism


As deep learning models grow larger and datasets become massive, a single GPU or server is often too slow, or the model no longer fits in GPU memory at all. Distributed training solves this by using multiple GPUs at the same time to train a model.

1. Types of Distributed Training

1.1 Data Parallelism

[Full dataset] → [Mini-batch shard 1] → GPU0 (model replica)
               → [Mini-batch shard 2] → GPU1 (model replica)
               → [Mini-batch shard 3] → GPU2 (model replica)

[Each GPU] → forward & backward → all-reduce → synchronize → parameter update

 

This is the most widely used approach. The same model is replicated on every GPU, and each mini-batch is split across the GPUs. Each GPU computes the forward and backward pass, then an all-reduce operation averages and synchronizes the gradients before the parameters are updated.

PyTorch์—๋Š” DataParallel(DP)๊ณผ DistributedDataParallel(DDP) ๋‘ ๊ฐ€์ง€๊ฐ€ ์žˆ๋‹ค. DP๋Š” ๋‹จ์ผ ๋จธ์‹ ์—์„œ ์—ฌ๋Ÿฌ GPU์— ๋ชจ๋ธ์„ ๋ณต์ œํ•ด ๋ฐ์ดํ„ฐ๋งŒ ๋‚˜๋ˆ  ๋„ฃ๋Š” ๋ฐฉ์‹์ด๋ฉฐ ๊ตฌํ˜„์ด ๊ฐ„๋‹จํ•˜์ง€๋งŒ, Python GIL(Global Interpreter Lock)๊ณผ single-process ๊ตฌ์กฐ๋กœ ์ธํ•ด ํ†ต์‹  ๋ณ‘๋ชฉ์ด ์‹ฌํ•˜๋‹ค. ๊ทธ๋ž˜์„œ ์‹ค์ œ๋กœ๋Š” ๊ฑฐ์˜ ํ•ญ์ƒ DDP๋ฅผ ์‚ฌ์šฉํ•œ๋‹ค. DDP๋Š” ๊ฐ GPU๋งˆ๋‹ค ๋ณ„๋„์˜ ํ”„๋กœ์„ธ์Šค๋ฅผ ์ƒ์„ฑํ•ด ํ†ต์‹  ๋ณ‘๋ชฉ์„ ์ค„์ด๊ณ  ํšจ์œจ์ ์œผ๋กœ all-reduce๋ฅผ ์ˆ˜ํ–‰ํ•ด ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ๋™๊ธฐํ™”ํ•œ๋‹ค.

  • Examples: PyTorch DDP, TensorFlow MirroredStrategy
  • Pros: simple to implement and easy to scale out to more GPUs.
  • Cons: the entire model is replicated on every GPU, so it cannot be used when the model does not fit in a single GPU's memory.
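
For reference, wrapping a model in DP takes a single line on one machine; a minimal sketch with a throwaway linear layer (the full DDP workflow, which is what you should normally use, is covered in section 2):

import torch
import torch.nn as nn

# DataParallel: one process replicates the module on every visible GPU,
# scatters each input batch across them, and gathers the outputs back.
model = nn.DataParallel(nn.Linear(1024, 10)).to('cuda')

inputs = torch.randn(64, 1024, device='cuda')   # batch is split across GPUs internally
outputs = model(inputs)                         # shape (64, 10), gathered on cuda:0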

1.2 Model Parallelism

[Input data] → GPU0 (Model Layer1-5) → GPU1 (Model Layer6-10) → GPU2 (Model Layer11-15)

forward → passed sequentially across GPUs → backward → passed sequentially across GPUs

 

๋ชจ๋ธ ์ž์ฒด๋ฅผ ์—ฌ๋Ÿฌ GPU์— ๋‚˜๋ˆ  ์ €์žฅํ•˜๊ณ  ์ˆœ์ฐจ์ ์œผ๋กœ forward ์—ฐ์‚ฐ์„ GPU๋ฅผ ๊ฑฐ์ณ ์ง„ํ–‰ํ•œ๋‹ค. GPT-3, T5 ๊ฐ™์€ ํ•˜๋‚˜์˜ GPU์— ์˜ฌ๋ผ๊ฐ€์ง€ ์•Š์„ ์ •๋„๋กœ ๊ฑฐ๋Œ€ํ•œ ๋ชจ๋ธ์—์„œ ์ฃผ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. PyTorch์—์„œ๋Š” ํŠน์ • ๋ ˆ์ด์–ด๋ฅผ ์ˆ˜๋™์œผ๋กœ cuda:0, cuda:1์— ์˜ฌ๋ฆฌ๊ฑฐ๋‚˜, Megatron-LM๊ณผ ๊ฐ™์€ ํ”„๋ ˆ์ž„์›Œํฌ๊ฐ€ Tensor Parallelism์„ ํ†ตํ•ด ์ž๋™์œผ๋กœ ๋ ˆ์ด์–ด๋ฅผ ๋‚˜๋ˆ ์ค€๋‹ค.

  • Examples: Megatron-LM (Tensor Parallelism), manual PyTorch splits
  • Pros: models too large for any single GPU can be split across several GPUs.
  • Cons: heavy inter-GPU traffic makes latency and bandwidth bottlenecks likely, and the dependency between layers limits parallelism.

1.3 Pipeline Parallelism

[Mini-batch 1] → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)
[Mini-batch 2] →               → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)
[Mini-batch 3] →                               → GPU0 (Stage1) → GPU1 (Stage2) → GPU2 (Stage3)

→ keep the pipeline full to minimize bubbles

๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ stage๋กœ ๋‚˜๋ˆ  GPU์— ๋ฐฐ์น˜ํ•˜๊ณ , ๋ฐ์ดํ„ฐ๋ฅผ ์—ฐ์†์ ์œผ๋กœ ํ˜๋ ค๋ณด๋‚ด ์ฒ˜๋ฆฌํ•˜๋Š” ๋ฐฉ์‹์ด๋‹ค. ๋ฏธ๋‹ˆ๋ฐฐ์น˜๋ฅผ ๋” ์ž‘๊ฒŒ ์ชผ๊ฐœ pipeline์„ ์ฑ„์›Œ idle time(bubble)์„ ์ค„์ธ๋‹ค. ๊ตฌ์กฐ ์ž์ฒด๋Š” ๋ชจ๋ธ ๋ณ‘๋ ฌํ™”์ฒ˜๋Ÿผ layer๋ฅผ ์—ฌ๋Ÿฌ GPU์— ๋‚˜๋ˆ ๋†“์ง€๋งŒ, GPU0์—์„œ ๋ฏธ๋‹ˆ๋ฐฐ์น˜1์„ ์—ฐ์‚ฐํ•˜๊ณ  GPU1๋กœ ๋„˜๊ธด ๋’ค GPU0์€ ๋ฐ”๋กœ ๋ฏธ๋‹ˆ๋ฐฐ์น˜2๋ฅผ ์—ฐ์‚ฐํ•˜๊ธฐ ์‹œ์ž‘ํ•ด GPU๋“ค์ด ์‰ฌ์ง€ ์•Š๋„๋ก ๋งŒ๋“œ๋Š” ๋ฐฉ์‹์ด๋‹ค. Deepspeed Pipeline์ด๋‚˜ PyTorch torch.distributed.pipeline.sync.Pipe๊ฐ€ ์ด๋ฅผ ์ง€์›ํ•œ๋‹ค.

  • Examples: DeepSpeed pipeline parallelism, PyTorch Pipe
  • Pros: splitting layers into stages spreads the memory load across GPUs at a finer granularity.
  • Cons: the sequential structure increases latency, and micro-batches must be tuned carefully so that bubbles stay small.

Which form of distributed training should you use, and when?

  • ์ž‘์€~์ค‘๊ฐ„ ํฌ๊ธฐ ๋ชจ๋ธ (์ˆ˜์ฒœ๋งŒ~์ˆ˜์–ต ํŒŒ๋ผ๋ฏธํ„ฐ): ๋‹จ์ˆœํžˆ DDP๋งŒ์œผ๋กœ ์ถฉ๋ถ„. GPU๋ฅผ ์—ฌ๋Ÿฌ ๊ฐœ ์“ธ์ˆ˜๋ก ํ•™์Šต ์†๋„๊ฐ€ ์„ ํ˜•์ ์œผ๋กœ ์ฆ๊ฐ€ํ•œ๋‹ค.
  • ํ•˜๋‚˜์˜ GPU ๋ฉ”๋ชจ๋ฆฌ์— ์˜ฌ๋ผ๊ฐ€์ง€ ์•Š๋Š” ๋ชจ๋ธ (์ˆ˜์‹ญ์–ต ํŒŒ๋ผ๋ฏธํ„ฐ): ๋ชจ๋ธ ๋ณ‘๋ ฌ ๋˜๋Š” ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌ์„ ์ ์šฉํ•ด GPU ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ๋ถ„์‚ฐ.
  • ์ˆ˜์‹ญ์–ต~์ˆ˜๋ฐฑ์–ต ํŒŒ๋ผ๋ฏธํ„ฐ ์ด์ƒ ์ดˆ๊ฑฐ๋Œ€ ๋ชจ๋ธ: DDP + ๋ชจ๋ธ ๋ณ‘๋ ฌ + ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌ์„ ํ•จ๊ป˜ ์กฐํ•ฉํ•œ๋‹ค. Megatron-LM, Deepspeed ZeRO-Infinity ๊ฐ™์€ ํ”„๋ ˆ์ž„์›Œํฌ๊ฐ€ ์ด๋ฅผ ์ž๋™ํ™”ํ•ด์ค€๋‹ค.

๋”ฐ๋ผ์„œ ๋ชจ๋ธ ํฌ๊ธฐ, GPU ๋ฉ”๋ชจ๋ฆฌ, ๋„คํŠธ์›Œํฌ ๋Œ€์—ญํญ์„ ๊ณ ๋ คํ•ด ์ ์ ˆํ•œ ๋ถ„์‚ฐ ๋ฐฉ์‹ ๋˜๋Š” ์กฐํ•ฉ์„ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.

 

2. Using PyTorch DistributedDataParallel (DDP)

PyTorch DistributedDataParallel (DDP) is the standard way to do data parallelism. The same model is placed on every GPU, each GPU trains on a different mini-batch, and gradients are synchronized through all-reduce before the parameters are updated.

๋‹จ์ผ ๋จธ์‹ (์—ฌ๋Ÿฌ GPU)์—์„œ๋„ ์“ธ ์ˆ˜ ์žˆ๊ณ , ๋ฉ€ํ‹ฐ ๋…ธ๋“œ(์—ฌ๋Ÿฌ ์„œ๋ฒ„)์—์„œ๋„ ๋™์ผํ•œ ์ฝ”๋“œ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

All-reduce is the communication operation in which the GPUs exchange their locally computed gradients, average them, and synchronize the result, so every GPU keeps identical parameters throughout training. It is usually carried out at high speed by NCCL and is triggered automatically by DDP.
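
To make the operation concrete, here is roughly what a gradient all-reduce boils down to. This is a sketch of what DDP does for you during backward, assuming a process group has already been initialized (as the setup() function in section 2.2 does); you normally never write these calls yourself:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run in every process.
local_grad = torch.randn(10, device='cuda')       # this process's own gradient

# all_reduce sums the tensor across all processes in place;
# dividing by world_size turns the sum into the average gradient.
dist.all_reduce(local_grad, op=dist.ReduceOp.SUM)
local_grad /= dist.get_world_size()
# Every process now holds the same averaged gradient.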

2.1 Prerequisites

  • Every machine (node) must have the same code, data, and libraries installed.
  • CUDA and NCCL must be installed (nccl is the usual PyTorch backend).
  • On a single machine it is enough to manage GPUs with CUDA_VISIBLE_DEVICES; for multi-node training you must also set the MASTER_ADDR and MASTER_PORT environment variables.

In a multi-node environment, MASTER_ADDR is normally set to the IP of the rank=0 machine.

export MASTER_ADDR=192.168.1.10
export MASTER_PORT=12355

 

2.2 Basic DDP Code Structure

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    # Every process joins the same process group; the master address/port tell
    # them where to rendezvous (localhost is enough on a single machine).
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def train(rank, world_size):
    setup(rank, world_size)
    device = torch.device(f'cuda:{rank}')   # one process per GPU

    model = MyModel().to(device)             # MyModel: your own model class
    ddp_model = DDP(model, device_ids=[rank])

    dataset = MyDataset()                    # MyDataset: your own dataset class
    # DistributedSampler gives each rank its own disjoint shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-4)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(10):
        sampler.set_epoch(epoch)             # re-shuffle the shards every epoch
        for batch in dataloader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = ddp_model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()                  # DDP all-reduces the gradients here
            optimizer.step()

    cleanup()

์ด ์Šคํฌ๋ฆฝํŠธ๋Š” GPU 0, GPU 1, GPU 2... ๊ฐ๊ฐ์ด ๋…๋ฆฝ์ ์ธ Python ํ”„๋กœ์„ธ์Šค๋กœ ์‹คํ–‰๋˜์–ด, ๋™์ผํ•œ ๋ชจ๋ธ์„ ๊ฐ GPU์—์„œ ํ•™์Šตํ•˜๊ณ  all-reduce๋ฅผ ํ†ตํ•ด ๊ทธ๋ž˜๋””์–ธํŠธ๋ฅผ ๋™๊ธฐํ™”ํ•œ๋‹ค.

 

2.3 Running It

โœ… ๋‹จ์ผ ๋จธ์‹ 

import torch.multiprocessing as mp

def main():
    world_size = 4  # number of GPUs (= processes) on this machine
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == '__main__':
    main()

๋‹จ์ผ ๋จธ์‹ ์—์„œ GPU๊ฐ€ 4๊ฐœ๋ผ๋ฉด torch.multiprocessing.spawn์„ ์‚ฌ์šฉํ•ด ์œ„์™€ ๊ฐ™์ด ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ๋‹ค.

 

✅ Multiple machines

With multiple machines, you run the same export commands and the same torchrun command on every machine. In other words, it is not the case that only the master (usually rank=0) machine sets the environment variables or runs torchrun; every machine sets the same environment variables and runs torchrun itself.

For example, with 3 machines, each with 4 GPUs:

export MASTER_ADDR=192.168.1.10
export MASTER_PORT=12355

# machine 1 (node_rank=0)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=0 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT train.py

# machine 2 (node_rank=1)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT train.py

# machine 3 (node_rank=2)
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=2 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT train.py

  • Run the same export on every machine. MASTER_ADDR is the IP of the rank=0 machine, which acts as the master.
    • It tells every node which machine is the rendezvous point for communication (the master IP = the rank=0 machine).
  • Then each machine runs torchrun with its own node_rank.
  • All the machines must run the command at the same time for DDP to establish communication.
  • It is not only the master machine that launches anything: every machine runs the same train.py with its own node_rank, and under torchrun that train.py reads its rank from environment variables (see the sketch below).
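
One thing to keep in mind: the train(rank, world_size) entry point from section 2.2 is written for mp.spawn. When you launch with torchrun, the processes are created by torchrun itself, which passes RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT through environment variables. A minimal sketch of a torchrun-style setup (only a sketch, not the full training script):

import os
import torch
import torch.distributed as dist

def setup_from_torchrun():
    # torchrun already set MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE in the environment,
    # so init_process_group can read everything it needs from there.
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])   # GPU index within this node
    torch.cuda.set_device(local_rank)
    return local_rank

if __name__ == '__main__':
    local_rank = setup_from_torchrun()
    # build the model on cuda:{local_rank}, wrap it with DDP(device_ids=[local_rank]),
    # create the DistributedSampler/DataLoader, and run the training loop as in 2.2
    dist.destroy_process_group()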

2.4 Key Concepts and Caveats

 

  • rank is a process's unique index across the whole job; with 8 GPUs, rank runs from 0 to 7.
  • world_size is the total number of GPUs (i.e., the total number of processes).
  • DistributedSampler makes each process read the same dataset in a different order/range. Call set_epoch(epoch) at the start of every epoch so that shuffling works correctly (a small sharding example follows this list).
  • Unlike DataParallel (DP), DDP runs one dedicated process per GPU and avoids the Python GIL bottleneck, which makes it far more efficient.
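
As a quick illustration of how DistributedSampler shards a dataset, the following toy example can be run locally without initializing a process group, because num_replicas and rank are passed explicitly (the dataset is just for illustration):

import torch
from torch.utils.data import DistributedSampler, TensorDataset

# A toy dataset of 8 samples, sharded as if world_size were 2.
dataset = TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# rank 0 sees indices [0, 2, 4, 6], rank 1 sees [1, 3, 5, 7]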

PyTorch DDP is a powerful, standard data-parallel framework for both single-machine and multi-machine training. The flow is setup → wrap the model with DDP → DistributedSampler → all-reduce during backward. Each GPU trains the same model independently while gradients are communicated and synchronized automatically, which makes large-scale training efficient.

 

 

3. Running Multi-Machine PyTorch DDP with Slurm

In a multi-node environment, Slurm is convenient because srun launches torchrun on all of the machines at once. Slurm also exposes the job layout (number of nodes, node ID, node list, and so on) through SLURM_* environment variables, so the job script can fill in everything torchrun needs without you exporting values by hand on each machine.

 

However, Slurm only works where the GPU servers have been set up as a Slurm cluster. You cannot simply launch a few EC2 instances and use it right away; the Slurm stack (slurmd, slurmctld, and their configuration) has to be installed and configured first.

 

3.1 Example Slurm Script (multi-node DDP)

Below is an example Slurm job script (train_job.sh).

#!/bin/bash
#SBATCH --job-name=ddp_train
#SBATCH --nodes=3              # number of nodes to use
#SBATCH --ntasks-per-node=1    # one launcher (torchrun) per node
#SBATCH --gres=gpu:4           # GPUs allocated per node
#SBATCH --cpus-per-task=4
#SBATCH --time=24:00:00
#SBATCH --partition=gpu

module load cuda/12.1
module load python/3.11

# Slurm does not set MASTER_ADDR itself; derive it from the first node of the allocation
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=12355

srun torchrun \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --nproc_per_node=4 \
    --node_rank=$SLURM_NODEID \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    train.py

Because the script requests #SBATCH --nodes=3, #SBATCH --ntasks-per-node=1 (one torchrun launcher per node), and #SBATCH --gres=gpu:4, the Slurm variables used above are set as follows:

Slurm variable          Example value
$SLURM_JOB_NUM_NODES    3
$SLURM_NODEID           0, 1, 2 (one value per node)

With these variables set, srun launches torchrun once on each node, so effectively the following commands run:

# on node 1
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=0 --master_addr=$MASTER_ADDR --master_port=12355 train.py

# on node 2
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=1 --master_addr=$MASTER_ADDR --master_port=12355 train.py

# on node 3
torchrun --nnodes=3 --nproc_per_node=4 --node_rank=2 --master_addr=$MASTER_ADDR --master_port=12355 train.py

If you want the job to run on specific nodes, add --nodelist=node05,node06,node07 or --nodelist=node[05-07] to the #SBATCH options.

 

3.2 Step-by-Step Behavior and Features

  1. srun automatically dispatches one task to each allocated node.
  2. Slurm exposes the job layout (number of nodes, node ID, node list, and so on) through SLURM_* environment variables, so the job script can set MASTER_ADDR, node_rank, and the rest consistently on every node.
  3. torchrun uses these values to locate the master node, and each process joins the job through init_process_group().
  4. All nodes run forward and backward at the same time and synchronize gradients via all-reduce.

In addition:

  • Standard output and error logs are saved automatically to slurm-<jobid>.out.
  • Commands such as squeue and scontrol show job <jobid> make it easy to check a job's state (pending, running, completed, failed).
  • Features such as array jobs, checkpointing, and preemption handling let long-running training jobs recover or restart even if a server goes down (a minimal DDP checkpoint save/resume sketch follows this list).
  • You can submit sbatch several times, or use an array job (--array=0-9) to launch 10 or even 100 hyperparameter-tuning runs at once.
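
To take advantage of that, the training script itself has to write and reload checkpoints. A minimal sketch of the usual DDP pattern (the file name and helper names are illustrative, not from the original script):

import torch
import torch.distributed as dist

# Minimal DDP checkpoint save/resume sketch; call only after init_process_group.
def save_checkpoint(ddp_model, optimizer, epoch, path='checkpoint.pt'):
    if dist.get_rank() == 0:                       # only one process writes the file
        torch.save({'model': ddp_model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch}, path)
    dist.barrier()                                 # everyone waits until the file exists

def load_checkpoint(ddp_model, optimizer, local_rank, path='checkpoint.pt'):
    # map_location places the saved tensors on this process's own GPU
    ckpt = torch.load(path, map_location=f'cuda:{local_rank}')
    ddp_model.module.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch'] + 1                       # epoch to resume from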

 

3.3 Launching Multi-Node PyTorch DDP Training on Slurm

์ฆ‰ ์œ„์˜ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์ž‘์„ฑํ•œ ๋’ค, Slurm์— ์ œ์ถœํ•˜๋ฉด ๋œ๋‹ค.

sbatch train_job.sh

Slurm then launches torchrun correctly on each of the three machines and multi-node DDP training begins.

์ฆ‰, ์›๋ž˜๋Š” ๊ฐ ๋จธ์‹ ์— ์ง์ ‘ ๋“ค์–ด๊ฐ€์„œ torchrun์„ ์‹คํ–‰ํ•ด์•ผ ํ•˜๋Š”๋ฐ, Slurm์ด ์ด๋ฅผ ๋Œ€์‹ ํ•ด์„œ ๊ฐ ๋…ธ๋“œ์— ๋“ค์–ด๊ฐ€ torchrun์„ ์ž๋™์œผ๋กœ ์‹คํ–‰ํ•ด ์ฃผ๊ธฐ ๋•Œ๋ฌธ์— ์‚ฌ์šฉ์ž๋Š” ๋ฉ”์ธ ์„œ๋ฒ„(=Slurm ์ปจํŠธ๋กค ๋…ธ๋“œ)์—์„œ sbatch train_job.sh ํ•œ ๋ฒˆ๋งŒ ์‹คํ–‰ํ•˜๋ฉด ๋œ๋‹ค.

 

💡 Summary

  • With Slurm + srun you do not have to specify ranks and the master IP by hand on every machine.
  • Slurm automatically provides the values needed for --node_rank and the other torchrun arguments through its environment variables, which is very convenient.
  • On large GPU clusters, combining Slurm with PyTorch DDP is therefore a stable and easy way to run multi-node training.

 

4. Implementing Model Parallelism in PyTorch

๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์ด ํ•˜๋‚˜์˜ GPU์— ์˜ฌ๋ฆฌ๊ธฐ์—๋Š” ๋„ˆ๋ฌด ์ปค์ง€๋ฉด, ๋ชจ๋ธ ๋ณ‘๋ ฌ(Model Parallelism) ์ด ํ•„์š”ํ•˜๋‹ค. PyTorch๋Š” DataParallel, DistributedDataParallel์ฒ˜๋Ÿผ ์ž๋™์œผ๋กœ ๋ชจ๋ธ์„ ์—ฌ๋Ÿฌ GPU์— ๋ณต์ œํ•ด ๋ฐ์ดํ„ฐ๋งŒ ๋ถ„์‚ฐํ•˜๋Š” ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ๋ฐฉ์‹๊ณผ ๋‹ฌ๋ฆฌ, ๋ชจ๋ธ ๋ณ‘๋ ฌ์€ ์‚ฌ์šฉ์ž๊ฐ€ ์ง์ ‘ ๋ชจ๋ธ์˜ ๋ ˆ์ด์–ด๋ฅผ GPU์— ๋‚˜๋ˆ  ๋ฐฐ์น˜ํ•ด forward, backward๋ฅผ GPU ๊ฐ„์— ์ˆœ์ฐจ์ ์œผ๋กœ ํ˜๋ ค๋ณด๋‚ด๋Š” ๋ฐฉ์‹์ด๋‹ค.

4.1 Basic Manual Model Parallelism

The most basic form of model parallelism puts some of the model's layers on cuda:0 and the rest on cuda:1, and moves the activations between GPUs during the forward pass.

 

✅ Step by step

  1. ๋ชจ๋ธ์˜ ๊ฐ ํŒŒํŠธ๋ฅผ ์›ํ•˜๋Š” GPU์— ์˜ฌ๋ฆฐ๋‹ค.
  2. forward ํ•จ์ˆ˜์—์„œ to('cuda:x')๋ฅผ ํ†ตํ•ด ์ถœ๋ ฅ์„ ๋‹ค์Œ GPU๋กœ ๋ณด๋‚ธ๋‹ค.
  3. backward๋Š” PyTorch์˜ Autograd๊ฐ€ ์ž๋™์œผ๋กœ GPU ๊ฐ„ ํ†ต์‹ ์„ ํ†ตํ•ด gradient๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค.

✅ Example code

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layer1 = nn.Linear(1024, 2048).to('cuda:0')  # first half on GPU 0
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(2048, 1024).to('cuda:1')  # second half on GPU 1

    def forward(self, x):
        x = x.to('cuda:0')
        x = self.layer1(x)
        x = self.relu(x)
        x = x.to('cuda:1')       # hand the activation over to GPU 1
        x = self.layer2(x)
        return x

# usage example
model = MyModel()
input_data = torch.randn(64, 1024).to('cuda:0')
output = model(input_data)

In this structure layer1 runs on GPU0 and layer2 on GPU1. Activations travel GPU0 → GPU1 during forward and gradients flow GPU1 → GPU0 during backward, computed automatically by autograd.
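
Training then works the same way as on a single GPU, as long as the loss is computed on the device that holds the final output. A short sketch continuing from the code above (the loss function and learning rate are arbitrary choices for illustration):

# Continuing the example above: targets must live on the same GPU as the output.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  # handles params on both GPUs

targets = torch.randn(64, 1024).to('cuda:1')    # layer2's output lives on cuda:1

optimizer.zero_grad()
output = model(input_data)                      # forward: cuda:0 -> cuda:1
loss = criterion(output, targets)               # computed on cuda:1
loss.backward()                                 # backward: cuda:1 -> cuda:0
optimizer.step()                                # updates parameters on both GPUs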

 

4.2 Using Pipeline Parallelism (torch.distributed.pipeline.sync.Pipe)

PyTorch's Pipe module splits a model into stages placed on multiple GPUs and cuts the input into micro-batches so the pipeline stays full and throughput is maximized. The API expects an nn.Sequential whose layers are already on their target GPUs and requires the RPC framework to be initialized. (In recent PyTorch releases this module has been deprecated in favor of torch.distributed.pipelining, but the idea is the same.)

 

✅ Step by step

  1. ๋ชจ๋ธ์„ torch.nn.Sequential๋กœ ์ •์˜ํ•œ๋‹ค.
  2. Pipe๋ฅผ ์‚ฌ์šฉํ•ด ๊ฐ stage๋ฅผ ์ง€์ •ํ•œ GPU์— ์ž๋™์œผ๋กœ ์˜ฌ๋ฆฌ๋„๋ก ํ•œ๋‹ค.
  3. chunks ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์‚ฌ์šฉํ•ด ์ž…๋ ฅ์„ ์ž˜๊ฒŒ ์ชผ๊ฐœ pipeline bubble์„ ์ค„์ธ๋‹ค.

✅ Example code

import os
import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe uses the RPC framework internally; initialize it even for a single process
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
rpc.init_rpc('worker', rank=0, world_size=1)

# Define the model as a Sequential whose stages already live on their target GPUs
model = nn.Sequential(
    nn.Linear(1024, 2048).to('cuda:0'),   # stage 1 on GPU 0
    nn.ReLU().to('cuda:0'),
    nn.Linear(2048, 1024).to('cuda:1'),   # stage 2 on GPU 1
)

# Wrap with Pipe: the input is split into 8 micro-batches to keep the pipeline full
model = Pipe(model, chunks=8)

# forward pass: the result comes back as an RRef
input_data = torch.randn(64, 1024).to('cuda:0')
output = model(input_data).local_value()

With this setup, stage 1 runs on cuda:0 and stage 2 on cuda:1; the input is split into 8 micro-batches that flow through the pipeline so the GPUs are kept busy.

 

4.3 Things to Watch Out For with Model Parallelism

  • ๋ชจ๋ธ ๋ณ‘๋ ฌ์€ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ(DDP)์ฒ˜๋Ÿผ ๊ฐ GPU๊ฐ€ ๊ฐ™์€ ๋ชจ๋ธ์„ ๋ณต์ œํ•˜๋Š” ๋ฐฉ์‹์ด ์•„๋‹ˆ๋ฏ€๋กœ, GPU ๊ฐ„ ํ†ต์‹ ์ด ์žฆ์•„ ๋„คํŠธ์›Œํฌ ๋ณ‘๋ชฉ์ด ๋ฐœ์ƒํ•˜๊ธฐ ์‰ฝ๋‹ค.
  • ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ GPU๋งˆ๋‹ค ๋‚˜๋‰˜์–ด ์žˆ์œผ๋ฏ€๋กœ, ๋ชจ๋ธ ์ €์žฅ(Checkpoint) ์‹œ ๊ฐ GPU์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๋”ฐ๋กœ ์ €์žฅํ•ด์•ผ ํ•œ๋‹ค.
  • Pipe๋ฅผ ์“ธ ๊ฒฝ์šฐ chunks๋ฅผ ์ ์ ˆํžˆ ์กฐ์ ˆํ•ด bubble(๋นˆ ์‹œ๊ฐ„)์„ ์ตœ์†Œํ™”ํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.
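
For the single-process manual split from section 4.1, a plain save/load works because each tensor in the state_dict remembers its device; loading to CPU first avoids requiring the exact same GPU layout at load time. A small sketch reusing the model from 4.1 (the file name is illustrative):

# Save: tensors are written out regardless of which GPU they live on.
torch.save(model.state_dict(), 'mp_model.pt')

# Load: bring everything to CPU first, then copy into the existing layers,
# which places each tensor back on whichever GPU that layer currently occupies.
state = torch.load('mp_model.pt', map_location='cpu')
model.load_state_dict(state)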

 

PyTorch๋Š” ๋ชจ๋ธ ๋ณ‘๋ ฌ์„ ์ˆ˜๋™์œผ๋กœ ๊ตฌํ˜„ํ•˜๊ฑฐ๋‚˜ Pipe๋ฅผ ์ด์šฉํ•ด ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌ์„ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋Š” ํ•˜๋‚˜์˜ GPU์— ๋‹ด๊ธฐ ์–ด๋ ค์šด ์ดˆ๋Œ€ํ˜• ๋ชจ๋ธ์„ ํ•™์Šตํ•  ๋•Œ ์œ ์šฉํ•˜๋ฉฐ, ํ•„์š”์— ๋”ฐ๋ผ Megatron-LM, Deepspeed ๊ฐ™์€ ๋” ๋ฐœ์ „๋œ ํ”„๋ ˆ์ž„์›Œํฌ๋กœ ๋„˜์–ด๊ฐˆ ์ˆ˜ ์žˆ๋‹ค.

 

 

๋ฐ˜์‘ํ˜•

'๐Ÿ› ๏ธ Engineering > Distributed Training & Inference' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€

vLLM์„ ํ™œ์šฉํ•œ Large-scale AI ๋ชจ๋ธ ๊ฐ€์†ํ™” | LLM Acceleration  (0) 2025.12.16
DeepSpeed ์™„๋ฒฝ ์ดํ•ดํ•˜๊ธฐ!  (1) 2025.07.07
PyTorch FSDP (Fully Sharded Data Parallel) ์™„๋ฒฝ ์ดํ•ดํ•˜๊ธฐ!  (4) 2025.07.06
GPU ํด๋Ÿฌ์Šคํ„ฐ: SuperPOD์™€ Slurm์˜ ๊ฐœ๋…๊ณผ ํ™œ์šฉ๋ฒ•  (1) 2025.07.03
'๐Ÿ› ๏ธ Engineering/Distributed Training & Inference' ์นดํ…Œ๊ณ ๋ฆฌ์˜ ๋‹ค๋ฅธ ๊ธ€
  • vLLM์„ ํ™œ์šฉํ•œ Large-scale AI ๋ชจ๋ธ ๊ฐ€์†ํ™” | LLM Acceleration
  • DeepSpeed ์™„๋ฒฝ ์ดํ•ดํ•˜๊ธฐ!
  • PyTorch FSDP (Fully Sharded Data Parallel) ์™„๋ฒฝ ์ดํ•ดํ•˜๊ธฐ!
  • GPU ํด๋Ÿฌ์Šคํ„ฐ: SuperPOD์™€ Slurm์˜ ๊ฐœ๋…๊ณผ ํ™œ์šฉ๋ฒ•
๋ญ…์ฆค
๋ญ…์ฆค
AI ๊ธฐ์ˆ  ๋ธ”๋กœ๊ทธ
    ๋ฐ˜์‘ํ˜•
  • ๋ญ…์ฆค
    moovzi’s Doodle
    ๋ญ…์ฆค
  • ์ „์ฒด
    ์˜ค๋Š˜
    ์–ด์ œ
  • ๊ณต์ง€์‚ฌํ•ญ

    • โœจ About Me
    • ๋ถ„๋ฅ˜ ์ „์ฒด๋ณด๊ธฐ (216)
      • ๐Ÿ“– Fundamentals (34)
        • Computer Vision (9)
        • 3D vision & Graphics (6)
        • AI & ML (16)
        • etc. (3)
      • ๐Ÿ› Research (78)
        • Deep Learning (7)
        • Perception (19)
        • OCR (7)
        • Multi-modal (8)
        • Image•Video Generation (18)
        • 3D Vision (4)
        • Material • Texture Recognit.. (8)
        • Large-scale Model (7)
        • etc. (0)
      • ๐Ÿ› ๏ธ Engineering (8)
        • Distributed Training & Infe.. (5)
        • AI & ML ์ธ์‚ฌ์ดํŠธ (3)
      • ๐Ÿ’ป Programming (92)
        • Python (18)
        • Computer Vision (12)
        • LLM (4)
        • AI & ML (18)
        • Database (3)
        • Distributed Computing (6)
        • Apache Airflow (6)
        • Docker & Kubernetes (14)
        • ์ฝ”๋”ฉ ํ…Œ์ŠคํŠธ (4)
        • etc. (7)
      • ๐Ÿ’ฌ ETC (4)
        • ์ฑ… ๋ฆฌ๋ทฐ (4)
  • ๋งํฌ

    • ๋ฆฌํ‹€๋ฆฌ ํ”„๋กœํ•„ (๋ฉ˜ํ† ๋ง, ๋ฉด์ ‘์ฑ…,...)
    • ใ€Ž๋‚˜๋Š” AI ์—”์ง€๋‹ˆ์–ด์ž…๋‹ˆ๋‹คใ€
    • Instagram
    • Brunch
    • Github
  • ์ธ๊ธฐ ๊ธ€

  • ์ตœ๊ทผ ๋Œ“๊ธ€

  • ์ตœ๊ทผ ๊ธ€

  • hELLOยท Designed By์ •์ƒ์šฐ.v4.10.3
๋ญ…์ฆค
PyTorch ๋ถ„์‚ฐ ํ•™์Šต ๊ธฐ์ดˆ: ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”, ๋ชจ๋ธ ๋ณ‘๋ ฌํ™”, ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌํ™”
์ƒ๋‹จ์œผ๋กœ

ํ‹ฐ์Šคํ† ๋ฆฌํˆด๋ฐ”