Multi-GPU Training
Multi-GPU Training (๋ค์ค GPU ํ์ต)์ ์ฌ๋ฌ ๊ฐ์ GPU๋ฅผ ์ฌ์ฉํ์ฌ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๋ ๋ฐฉ๋ฒ์ด๋ค. ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ ์ ์ปค์ง๊ณ ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ ๋ํ ๋ฐฉ๋ํ๊ธฐ ๋๋ฌธ์ ๋ค์ค GPU๋ฅผ ์ฌ์ฉํ์ฌ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๋ ๊ฒ์ ์ฌ์ค์ ํ์์ ์ธ ๊ธฐ์ ์ด๋ผ ๋ณผ ์ ์๋ค.
Pytorch์์๋ multi-gpu ํ์ต์ ์ํ ๋ช ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ๊ณตํ๋ค.
Data Prarallel (DP)
# DataParallel ๋ชจ๋๋ก ๋ชจ๋ธ ๊ฐ์ธ๊ธฐ
model = nn.DataParallel(model)
torch.nn.DataParallel ๋ชจ๋์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ผ๋ก, ๊ต์ฅํ ๊ฐ๋จํ๊ฒ ๋์ํ์ง๋ง ๋ช ๊ฐ์ง ์น๋ช ์ ์ธ ๋จ์ ์ด ์กด์ฌํ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์ฅ์
- ์์ฃผ ๊ฐ๋จํ๋ค (๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํด๋น ํจ์๋ก ๊ฐ์ธ๊ธฐ๋ง ํ๋ฉด ๋์)
- ๋จ์
- ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ฆ๊ฐ : ๊ฐ GPU์์ ๋ชจ๋ธ์ ๋ณต์ฌ๋ณธ์ ๋ง๋ค์ด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํ๊ธฐ GPU์ ์๊ฐ ์ฆ๊ฐํ ์๋ก ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ฆ๊ฐ
- ํต์ ๋ณ๋ชฉ ํ์ ๋ฐ์ : ๊ฐ GPU์์ ์ฐ์ฐ์ ์ํํ๊ณ ์ฐ์ฐ ๊ฒฐ๊ณผ๋ฅผ ํ๋์ GPU๋ก ๋ชจ์ ํ์ ๋ชจ๋ธ์ ์ ๋ฐ์ดํธํ๊ธฐ ๋๋ฌธ์ GPU ๊ฐ์ ๋ฐ์ดํฐ๋ฅผ ๋ณต์ฌํ๊ณ ํต์ ํ๋ ๋ฐ ์๊ฐ์ด ์์. ๋ํ ํ๋์ GPU๋ก ์ฐ์ฐ ๊ฒฐ๊ณผ๋ฅผ ๋ชจ์ผ๊ธฐ ๋๋ฌธ์ GPU ์๊ฐ ์ฆ๊ฐํ ์๋ก ํ๋์ GPU์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ์ฆ๊ฐํด ํจ์จ์ ์ธ ์ฌ์ฉ์ด ๋ถ๊ฐ๋ฅ.
Distributed Data Parallel (DDP)
torch.nn.parallel.DistributedDataParallel ๋ชจ๋์ ๋ถ์ฐ ํ์ต ํ๊ฒฝ์์ ์ฌ๋ฌ GPU๋ค ๊ฐ์ ํต์ ์ ์ฒ๋ฆฌํ ์ ์๋ ๊ธฐ๋ฅ์ ์ ๊ณตํ๊ธฐ ๋๋ฌธ์ ๋ค์ค GPU๋ฟ๋ง ๋ค์ค ๋จธ์ ์ ์ฌ์ฉํด์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ต์ํฌ ์๋ ์๋ค.
์ฌ๋ฌ๊ฐ์ง ๋ฉด์์ DDP๊ฐ DP๋ณด๋ค ์ฐ์ํ๊ณ ํนํ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ์ ์ ํฌ๊ธฐ๊ฐ ํด์๋ก DDP๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ ๋ฆฌํ๋ค. ํ์ง๋ง DDP์ ๊ฒฝ์ฐ ๋ถ์ฐ ํ์ต์ ์ํ ์ฝ๋ ํ๊ฒฝ์ ์ธํ ํ๋ ๊ฒ์ด ์กฐ๊ธ ๋ณต์กํ๋ค๋ ๋จ์ ์ด ์๋ค.
- DDP๋ ๋ค์ค ํ๋ก์ธ์ค ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ ๋ณต์ ๋ณธ ๊ฐ์ GIL connection ์ด์๊ฐ ์์
- ๋จ์ผ GPU ํ์ต ์ฝ๋์ ๋นํด ๋ช ๊ฐ์ง ์ถ๊ฐ/์์ ํ์
- ํ์ต ์ฝ๋๋ฅผ ํจ์ํํ๊ณ ํด๋น ํจ์๋ฅผ ๋ฉํฐํ๋ก์ธ์ฑ ๋ชจ๋๋ก ์คํํ๋ ๋ฐฉ์์ผ๋ก ๋ถ์ฐ ํ์ต์ ์งํ ๊ฐ๋ฅ
* ๋ณธ ํฌ์คํ ์์๋ ๋จ์ผ ๋จธ์ ๋ค์ค GPU ํ๊ฒฝ์ ๋ถ์ฐ ํ์ต์ ๋ํด ์ค๋ช ํ๋ค. (๋ค์ค๋จธ์ (์ฌ๋ฌ ๋์ ์ปดํจํฐ(์๋ฒ)) X)
torch.distributed.init_process_group
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://127.0.0.1:23456',
world_size=ngpus_per_node,
rank=process_id)
๋ถ์ฐ ํ์ต์ ์ํ ์ด๊ธฐํ ํจ์๋ก ๊ฐ ํ๋ก์ธ์ค๋ง๋ค ํธ์ถ๋์ด์ผ ํ๊ณ , ๋ถ์ฐ ํ์ต์ ์ํด ํ์ํ ๋ชจ๋ ์ค์ ์ด ์๋ฃ๋ ํ์๋ง ๋ค์ ๋จ๊ณ๋ก ์งํํ ์ ์๋ค. ๋ฐ๋ผ์, ๋ชจ๋ ํ๋ก์ธ์ค๊ฐ init_process_group ํจ์๋ฅผ ํธ์ถํ๊ธฐ ์ ๊น์ง๋ ์คํ์ด ์ฐจ๋จ๋๋ค.
- backend: ์ฌ์ฉํ ๋ถ์ฐ ์ฒ๋ฆฌ ๋ฐฑ์๋
- GPU training : 'NCCL'
- CPU training : 'Gloo'
- init_method: ์ด๊ธฐํ ๋ฐฉ๋ฒ์ผ๋ก 'NCCL' ๋ฐฑ์๋์ ๋จ์ผ ๋จธ์ ๋ค์ค GPU ์ฌ์ฉ ์ 'tcp://localhost:port'๋ก ์ง์
- world_size: ์ ์ฒด ํ๋ก์ธ์ค ๊ฐ์ (๋จ์ผ ๋จธ์ ์ ๊ฒฝ์ฐ GPU ๊ฐ์)
- rank: ํ์ฌ ํ๋ก์ธ์ค id. rank๋ 0๋ถํฐ world_size - 1๊น์ง์ ๊ฐ์ ๊ฐ์ง
DistributedSampler
train_sampler = DistributedSampler(dataset=train_set, shuffle=True)
batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, opts.batch_size, drop_last=True)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler_train, num_workers=opts.num_workers)
- DistributedSampler๋ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌํ์ต(distributed data parallel training)์ ๊ฐ ํ๋ก์ธ์ค๊ฐ ๋ฏธ๋๋ฐฐ์น๋ฅผ ๋๋์ด ํ์ตํ ๋ฐ์ดํฐ ์ํ์ ๊ฒฐ์ ํ๋ ์ญํ
- ์ผ๋ฐ์ ์ผ๋ก ๊ฐ ํ๋ก์ธ์ค๋ ์ ์ฒด ๋ฐ์ดํฐ์ ์ ๊ณ ๋ฃจ ๋๋์ด ํ์ตํ์ง๋ง, ์ด๋ ๊ฒ ๋๋์ด ํ์ตํ๋ ๊ฒฝ์ฐ ๋ค๋ฅธ ํ๋ก์ธ์ค์์ ํ์ตํ๋ ๋ฐ์ดํฐ ์ํ๊ณผ ์ค๋ณต๋๋ ๊ฒฝ์ฐ๊ฐ ๋ฐ์ํ ์ ์์
- DistributedSampler๋ ๋ฐ์ดํฐ์ ์ ๊ฐ ์ํ์ ๋ํ ์ธ๋ฑ์ค๋ฅผ ๋ถ์ฐ์ฒ๋ฆฌ์ ๋ง๊ฒ ์๋ก์ด ์์๋ก ๋ง๋ค์ด์ฃผ๊ณ , ํด๋น ์ธ๋ฑ์ค๋ฅผ ์ด์ฉํ์ฌ ํ๋ก์ธ์ค๋ค ๊ฐ์ ์ค๋ณต ์๋ ๋ฐ์ดํฐ ๋ถ๋ฐฐ ๊ฐ๋ฅ
DistributedDataParallel
model = DistributedDataParallel(module=model, device_ids=[local_gpu_id])
- DistributedDataParallel์ ๊ฐ๊ฐ์ GPU์ ๋ฐ์ดํฐ์ ๋ชจ๋ธ์ด ๋ถ๋ฐฐ
- ๊ฐ๊ฐ์ GPU์์ ๊ณ์ฐ๋ ๊ทธ๋๋์ธํธ๋ค์ด ์ ์ฒด์ ์ผ๋ก ๋๊ธฐํ๋๋ฉฐ ํฉ์ณ์ง๋ ๋ฐฉ์์ผ๋ก ํ์ต
- ์ด๋ฅผ ํตํด ๊ฐ GPU์์ ๊ณ์ฐ๋ ๊ทธ๋๋์ธํธ๊ฐ Master GPU์์ ์ฒ๋ฆฌ๋์ด ๊ฐ์ค์น๋ฅผ ์ ๋ฐ์ดํธ(DP๋ณด๋ค ํจ์ฌ ๋ ํจ์จ์ )
torch.multiprocessing.spawn
import torch.multiprocessing as mp
def train(rank, world_size):
# ๋ถ์ฐ ํ์ต ์ฝ๋ ์์ฑ
pass
if __name__ == '__main__':
world_size = 4
mp.spawn(train, args=(world_size,), nprocs=world_size)
PyTorch์์ ์ ๊ณตํ๋ ๋ถ์ฐ ํ์ต์ ์ํ ํ๋ก์ธ์ค ๊ทธ๋ฃน์ ์์ฑํ๋ ํจ์๋ก ๋ถ์ฐ ํ์ต์ ์ํด ์ฌ๋ฌ ํ๋ก์ธ์ค๋ฅผ ์คํํ๋ ๋ฐ ์ฌ์ฉ
- fn : ์คํํ ํจ์. ํจ์์ ์ฒซ ๋ฒ์งธ ํ๋ผ๋ฏธํฐ๋ rank๋ก ์ง์ .
- args : ํจ์์ ์ ๋ฌํ ์ธ์๋ฅผ ์ง์ . (ํจ์์ ์ฒซ ๋ฒ์งธ ํ๋ผ๋ฏธํฐ์ธ rank๋ ์ ์ธ)
- nprocs : ์คํํ ํ๋ก์ธ์ค ๊ฐ์
์ ๋ฆฌํด๋ณด๋ฉด 'torch.distributed.init_process_group'์ผ๋ก ๋ถ์ฐ ํ์ต ํ๊ฒฝ์ ์ด๊ธฐํ, ๋ฐ์ดํฐ๊ฐ ํ๋ก์ธ์ค ๊ฐ ์ค๋ณต๋์ง ์๋๋ก 'DistributedSampler'๋ฅผ ์ฌ์ฉ, ๋ชจ๋ธ์ 'DistributedDataParallel'๋ก wrappingํ ํ์ต ์ฝ๋๋ฅผ ํจ์๋ก ๊ตฌ์ฑํ๊ณ (์ฒซ ๋ฒ์งธ ํ๋ผ๋ฏธํฐ๋ rank), ๊ตฌ์ฑ๋ ํจ์๋ฅผ 'torch.multiprocessing.spawn'๋ฅผ ์ฌ์ฉํด์ ์คํ์ํค๋ฉด ๋๋ค๋ ๋ป์ด๋ค.
DDP ์ฝ๋ ์์
- ๋ชจ๋ธ : resnet18 ๋ชจ๋ธ
- ๋ฐ์ดํฐ์ : cifar10
- ๋จ์ผ ๋จธ์ ๋ค์ค GPU ์์
- torch.multiprocessing.spawn, torch.distributed.init_process_group, DistributedSampler,DistributedDataParallel ๊ฐ ๋ชจ๋ ์ ์ฉ๋ ์์๋ก ๊ฐ์ ํ๊ฒฝ์ ๋ง๊ฒ ์์ ํด์ ์ฌ์ฉ ๊ฐ๋ฅ
import argparse
import torch
import torchvision.transforms as transforms
from torchvision.datasets.cifar import CIFAR10
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torchvision.models import resnet18
def get_args_parser():
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--epoch', type=int, default=3)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--port', type=int, default=2033)
parser.add_argument('--root', type=str, default='./cifar')
parser.add_argument('--local_rank', type=int)
return parser
def init_distributed_training(rank, opts):
# 1. setting for distributed training
opts.rank = rank
opts.gpu = opts.rank % torch.cuda.device_count()
local_gpu_id = int(opts.gpu_ids[opts.rank])
torch.cuda.set_device(local_gpu_id)
if opts.rank is not None:
print("Use GPU: {} for training".format(local_gpu_id))
# 2. init_process_group
torch.distributed.init_process_group(backend='nccl',
init_method='tcp://127.0.0.1:' + str(opts.port),
world_size=opts.ngpus_per_node,
rank=opts.rank)
# if put this function, the all processes block at all.
torch.distributed.barrier()
# convert print fn iif rank is zero
setup_for_distributed(opts.rank == 0)
print('opts :',opts)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def main(rank, opts):
init_distributed_training(rank, opts)
local_gpu_id = opts.gpu
train_set = CIFAR10(root=opts.root,
train=True,
transform=transforms.ToTensor(),
download=True)
train_sampler = DistributedSampler(dataset=train_set, shuffle=True)
batch_sampler_train = torch.utils.data.BatchSampler(train_sampler, opts.batch_size, drop_last=True)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler_train, num_workers=opts.num_workers)
model = resnet18(pretrained=False)
model = model.cuda(local_gpu_id)
model = DistributedDataParallel(module=model, device_ids=[local_gpu_id])
criterion = torch.nn.CrossEntropyLoss().to(local_gpu_id)
optimizer = torch.optim.SGD(params=model.parameters(),
lr=0.01,
weight_decay=0.0005,
momentum=0.9)
print(f'[INFO] : ํ์ต ์์')
for epoch in range(opts.epoch):
model.train()
train_sampler.set_epoch(epoch)
for i, (images, labels) in enumerate(train_loader):
images = images.to(local_gpu_id)
labels = labels.to(local_gpu_id)
outputs = model(images)
optimizer.zero_grad()
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f'[INFO] : {epoch} ๋ฒ์งธ epoch ์๋ฃ')
print(f'[INFO] : Distributed ํ์ต ํ
์คํธ์๋ฃ')
if __name__ == '__main__':
parser = argparse.ArgumentParser('Distributed training test', parents=[get_args_parser()])
opts = parser.parse_args()
opts.ngpus_per_node = torch.cuda.device_count()
opts.gpu_ids = list(range(opts.ngpus_per_node))
opts.num_workers = opts.ngpus_per_node * 4
torch.multiprocessing.spawn(main,
args=(opts,),
nprocs=opts.ngpus_per_node,
join=True)