
[pytorch] Multi-GPU Training | Multi-GPU Training Example | Distributed Data Parallel (DDP) | Data Parallel (DP)

by ๋ญ…์ฆค 2023. 4. 17.
Multi-GPU Training

 

Multi-GPU training means training a deep learning model with several GPUs. Since models keep getting larger and datasets are massive, training with multiple GPUs has become a practically essential technique.

 

PyTorch provides several ways to do multi-GPU training.

 

Data Parallel (DP)
# wrap the model with the DataParallel module
model = nn.DataParallel(model)

 

This approach uses the torch.nn.DataParallel module. It is extremely simple to get working, but it has a few critical drawbacks. (A minimal usage sketch follows the list below.)

 

  • Advantages
    • Very simple (just wrap the model with the module and it works)
  • Disadvantages
    • Increased memory usage: a copy of the model is created on every GPU, so memory usage grows as the number of GPUs grows
    • Communication bottleneck: each GPU runs its own computation, the results are gathered on a single GPU, and only then is the model updated, so time is spent copying data and communicating between GPUs. Because everything is gathered on one GPU, that GPU's memory usage grows as the number of GPUs increases, which prevents efficient utilization.
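A minimal usage sketch of DP (the toy model and tensor shapes below are placeholders for illustration, not part of the original example):
import torch
import torch.nn as nn

# toy model used only for illustration
model = nn.Linear(512, 10)

if torch.cuda.device_count() > 1:
    # replicate the model on every visible GPU; the input batch is split along dim 0
    model = nn.DataParallel(model)
model = model.cuda()

x = torch.randn(256, 512).cuda()   # the batch is scattered across the GPUs
y = model(x)                       # per-GPU outputs are gathered back onto GPU 0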

 

 

Distributed Data Parallel (DDP)

 

The torch.nn.parallel.DistributedDataParallel module provides the functionality to handle communication between GPUs in a distributed training environment, so a model can be trained not only on multiple GPUs but also across multiple machines.

 

DDP is superior to DP in many respects, and the larger the model and the dataset, the more advantageous DDP becomes. The drawback is that setting up the code for distributed training is a bit more involved.

 

  • DDP uses multi-process parallelism, so there is no GIL contention between model replicas
  • Compared to single-GPU training code, a few additions and changes are required
  • Distributed training is carried out by wrapping the training code in a function and launching that function with the multiprocessing module

 

* This post covers distributed training on a single machine with multiple GPUs (not multiple machines, i.e. several computers/servers).

 

torch.distributed.init_process_group
torch.distributed.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:23456',
                            world_size=ngpus_per_node,
                            rank=process_id)

This function initializes the distributed training environment. It must be called in every process, and the next step can only proceed once all of the setup required for distributed training is complete; execution therefore blocks until every process has called init_process_group. (A small setup/cleanup sketch follows the parameter list below.)

  • backend: the distributed backend to use
    • GPU training: 'nccl'
    • CPU training: 'gloo'
  • init_method: how the process group is initialized; for the 'nccl' backend on a single machine with multiple GPUs, use 'tcp://localhost:port'
  • world_size: the total number of processes (on a single machine, the number of GPUs)
  • rank: the id of the current process; rank takes values from 0 to world_size - 1
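For reference, a minimal per-process setup/cleanup sketch under the same assumptions (single machine, one process per GPU; the port number is arbitrary):
import torch
import torch.distributed as dist

def setup(rank, world_size, port=23456):
    # called once in every spawned process
    dist.init_process_group(backend='nccl',
                            init_method=f'tcp://127.0.0.1:{port}',
                            world_size=world_size,
                            rank=rank)
    torch.cuda.set_device(rank)  # bind this process to a single GPU

def cleanup():
    # tear down the process group once training is finished
    dist.destroy_process_group()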

 

 

DistributedSampler 
train_sampler = DistributedSampler(dataset=train_set, shuffle=True)
batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, opts.batch_size, drop_last=True)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler_train, num_workers=opts.num_workers)
  • During distributed data parallel training, DistributedSampler decides which data samples each process receives for its part of the mini-batch
  • Each process should train on an even share of the full dataset, but with a naive split the samples one process trains on can overlap with those of another process
  • DistributedSampler builds a new ordering of the dataset indices suited to distributed processing and uses those indices to hand out the data to the processes without overlap (see the usage sketch below)
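As an alternative to the BatchSampler form shown above, the sampler can also be passed directly to the DataLoader. A sketch (the dataset, batch size, and epoch count are placeholders), including the set_epoch call needed to get a different shuffle every epoch:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# assumes init_process_group has already been called in this process
# and that train_set is any torch Dataset defined elsewhere
train_sampler = DistributedSampler(dataset=train_set, shuffle=True)
train_loader = DataLoader(train_set, batch_size=64, sampler=train_sampler, num_workers=4)

for epoch in range(num_epochs):
    # reseed the sampler so every epoch uses a different shuffle across processes
    train_sampler.set_epoch(epoch)
    for images, labels in train_loader:
        ...  # forward / backward / step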

 

 

DistributedDataParallel

(Figure: collective communication)

model = DistributedDataParallel(module=model, device_ids=[local_gpu_id])
  • With DistributedDataParallel, the data and the model are distributed to each GPU
  • The gradients computed on each GPU are synchronized and combined across all processes during training
  • The synchronization uses collective communication (all-reduce), so every GPU ends up with the averaged gradient and updates its own copy of the weights; no single master GPU has to gather everything, which makes it far more efficient than DP (see the sketch below)
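To make the collective communication explicit, the sketch below averages gradients by hand with all_reduce. This is purely illustrative: DDP performs the equivalent synchronization automatically during backward(), in buckets and overlapped with the backward pass.
import torch.distributed as dist

def average_gradients(model):
    # roughly what DDP's gradient synchronization boils down to, written out by hand
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # sum this gradient across all processes, then divide by the number of processes
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size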

 

torch.multiprocessing.spawn
import torch.multiprocessing as mp

def train(rank, world_size):
    # distributed training code goes here
    pass

if __name__ == '__main__':
    world_size = 4
    mp.spawn(train, args=(world_size,), nprocs=world_size)

A function provided by PyTorch that launches the worker processes used for distributed training; it is used to run multiple processes, one per rank.

  • fn: the function to run; its first parameter is the rank, which is filled in automatically
  • args: the arguments to pass to the function (excluding rank, the first parameter)
  • nprocs: the number of processes to launch
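Conceptually, mp.spawn(fn, args=(a, b), nprocs=N) starts N processes and calls fn(rank, a, b) in each one, with rank running from 0 to N - 1. A tiny runnable sketch (the worker function and its message are made up for illustration):
import torch.multiprocessing as mp

def worker(rank, msg, world_size):
    # rank is supplied automatically by mp.spawn
    print(f'[{rank}/{world_size}] {msg}')

if __name__ == '__main__':
    # launches 4 processes; each one calls worker(rank, 'hello', 4)
    mp.spawn(worker, args=('hello', 4), nprocs=4, join=True)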

 

To sum up: initialize the distributed training environment with 'torch.distributed.init_process_group', use 'DistributedSampler' so that data is not duplicated across processes, wrap the model with 'DistributedDataParallel', put the training code into a function whose first parameter is the rank, and launch that function with 'torch.multiprocessing.spawn'.

 

DDP Code Example
  • Model: resnet18
  • Dataset: CIFAR10
  • Single machine, multi-GPU example
  • The example uses torch.multiprocessing.spawn, torch.distributed.init_process_group, DistributedSampler, and DistributedDataParallel together; adapt it to your own environment as needed
import argparse

import torch
import torchvision.transforms as transforms
from torchvision.datasets.cifar import CIFAR10

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet18

def get_args_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--epoch', type=int, default=3)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--port', type=int, default=2033)
    parser.add_argument('--root', type=str, default='./cifar')
    parser.add_argument('--local_rank', type=int)
    return parser

def init_distributed_training(rank, opts):
    # 1. setting for distributed training
    opts.rank = rank
    opts.gpu = opts.rank % torch.cuda.device_count()
    local_gpu_id = int(opts.gpu_ids[opts.rank])
    torch.cuda.set_device(local_gpu_id)
    
    if opts.rank is not None:
        print("Use GPU: {} for training".format(local_gpu_id))

    # 2. init_process_group
    torch.distributed.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:' + str(opts.port),
                            world_size=opts.ngpus_per_node,
                            rank=opts.rank)

    # barrier: block here until every process has reached this point
    torch.distributed.barrier()

    # override print so that only rank 0 prints
    setup_for_distributed(opts.rank == 0)
    print('opts :',opts)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def main(rank, opts):
    init_distributed_training(rank, opts)
    local_gpu_id = opts.gpu

    train_set = CIFAR10(root=opts.root,
                        train=True,
                        transform=transforms.ToTensor(),
                        download=True)

    train_sampler = DistributedSampler(dataset=train_set, shuffle=True)
    
    batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, opts.batch_size, drop_last=True)
    train_loader = DataLoader(train_set, batch_sampler=batch_sampler_train, num_workers=opts.num_workers)

    model = resnet18(pretrained=False)
    model = model.cuda(local_gpu_id)
    model = DistributedDataParallel(module=model, device_ids=[local_gpu_id])

    criterion = torch.nn.CrossEntropyLoss().to(local_gpu_id)
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=0.01,
                                weight_decay=0.0005,
                                momentum=0.9)

    print('[INFO] Training started')
    for epoch in range(opts.epoch):

        model.train()
        train_sampler.set_epoch(epoch)

        for i, (images, labels) in enumerate(train_loader):
            images = images.to(local_gpu_id)
            labels = labels.to(local_gpu_id)
            outputs = model(images)

            optimizer.zero_grad()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        print(f'[INFO] Epoch {epoch} finished')

    print('[INFO] Distributed training test finished')

    # clean up the process group before this process exits
    torch.distributed.destroy_process_group()

if __name__ == '__main__':

    parser = argparse.ArgumentParser('Distributed training test', parents=[get_args_parser()])
    opts = parser.parse_args()
    opts.ngpus_per_node = torch.cuda.device_count()
    opts.gpu_ids = list(range(opts.ngpus_per_node))
    opts.num_workers = opts.ngpus_per_node * 4

    torch.multiprocessing.spawn(main,
             args=(opts,),
             nprocs=opts.ngpus_per_node,
             join=True)
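If you save the script and run it directly (for example python ddp_example.py --epoch 3 --batch_size 256, where the file name is arbitrary), mp.spawn starts one process per visible GPU, each process trains on its own shard of the dataset, and only rank 0 prints progress thanks to setup_for_distributed.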