PyTorch 분산 학습 기초: 데이터 병렬화, 모델 병렬화, 파이프라인 병렬화
·
🛠️ Engineering/Distributed Training
딥러닝 모델이 점점 커지고 데이터도 방대해지면서, 단일 GPU나 서버만으로는 학습 속도가 너무 느리거나 GPU 메모리가 부족해 학습이 불가능해진다. 이를 해결하기 위해 여러 GPU를 동시에 활용해 모델을 학습시키는 것이 바로 분산 학습이다.1. 분산 학습 종류1.1 데이터 병렬화(Data Parallelism)[전체 데이터] → [분할된 미니배치1] → GPU0 (모델 복제) → [분할된 미니배치2] → GPU1 (모델 복제) → [분할된 미니배치3] → GPU2 (모델 복제)[각 GPU] → forward & backward → all-reduce → 동기화 → 파라미터 업데이트 가장 보편적으로 사용되는 방식이다. 동일한 모델을 여러 GPU에 복제하고, 미니배치 ..