Mixed Precision
์ผ๋ฐ์ ์ธ neural network์์๋ 32-bit floating point(FP32) precision์ ์ด์ฉํ์ฌ ํ์ต์ ์ํค๋๋ฐ, ์ต์ ํ๋์จ์ด์์๋ lower precision(FP16) ๊ณ์ฐ์ด ์ง์๋๋ฉด์ ์๋์์ ์ด์ ์ ์ป์ ์ ์๋ค. ํ์ง๋ง FP16์ผ๋ก precision์ ์ค์ด๋ฉด ์๋ฅผ ํํํ๋ ๋ฒ์๊ฐ ์ค์ด๋ค์ด ํ์ต ์ฑ๋ฅ์ด ์ ํ๋ ์ ์๋ค.
Mixed Precision์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต ๊ณผ์ ์์ ๋ถ๋์์์ ์ฐ์ฐ์ ์ ๋ฐ๋๋ฅผ ํผํฉํ์ฌ ์ฌ์ฉํ๋ ๊ธฐ์ ๋ก, ํ์ต ์๋๋ฅผ ๋์ด๊ณ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๋ ๋ฐ ๋์์ ์ค๋ค. Mixed Precision์ ๋๊ฐ FP32(32๋นํธ ๋ถ๋์์์ )์ FP16(16๋นํธ ๋ถ๋์์์ )์ ์กฐํฉํ์ฌ ์ฌ์ฉํ๋ฉฐ, ๊ฐ์ค์น์ ๊ทธ๋๋์ธํธ๋ FP16๋ก ์ ์ฅํ๊ณ ์ฐ์ฐ์ FP16๋ก ์ํํ๋ฉด์ ์ผ๋ถ ์ฐ์ฐ์์๋ FP32๋ก ์ ํํ์ฌ ์ค๋ฒํ๋ก์ฐ ๋ฐ ์ธ๋ํ๋ก์ฐ๋ฅผ ๋ฐฉ์งํ๋ค.
Mixed Precision์ ์ฌ์ฉํ๋ฉด ์ผ๋ฐ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ๊ณ ํ์ต ์๋๋ฅผ ํฅ์์ํค๋ฉด์ ์ ํ๋๋ ์ด๋ ์ ๋ ์ ์งํ ์ ์๊ธฐ์ ๋ชจ๋ธ ํ์ต ์ ์์ฃผ ์ฌ์ฉ๋๋ค.
AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP
torch.amp์ torch.autocast๋ฅผ ์ฌ์ฉํ๋ฉด pytorch ๋ชจ๋ธ์ mixed precision์ ๊ฐ๋จํ๊ฒ ์ ์ฉํ ์ ์๋ค.
Docs url :https://pytorch.org/docs/stable/amp.html
์ ์ค๋ช ์์, ops๋ ์ ํ๋๋ฅผ ์ ์งํ๋ฉด์ ์ฑ๋ฅ์ ํฅ์์ํค๊ธฐ ์ํด autocast์์ ์ ํํ ํน์ op dtype์์๋ง ์คํ๋๋ค๊ณ ๋์ ์๋ค. Autocast Op Reference๋ฅผ ๋ณด๋ฉด ์์ธํ ์ค๋ช ์ ํ์ธํ ์ ์๋ค.
๋ํ autocast๋ ์๋์ฒ๋ผ forward ๋ฉ์๋์ ๋ฐ์ฝ๋ ์ดํฐ๋ก ์ฌ์ฉํ ์๋ ์์ต๋๋ค.
torch.Autocast() ์ฌ์ฉ ์์
# ํ์ต
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
model = YourModel()
optimizer = YourOptimizer(model.parameters(), lr=learning_rate)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
scaler = GradScaler()
for epoch in range(num_epochs):
for inputs, targets in dataloader:
inputs, targets = inputs.to(device), targets.to(device)
# autocast๋ฅผ ์ฌ์ฉํ์ฌ Mixed Precision ํ์ฑํ
with autocast():
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
# GradScaler๋ฅผ ์ฌ์ฉํ์ฌ ๊ทธ๋๋์ธํธ ์ค์ผ์ผ๋ง
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
๋ชจ๋ธ ํ์ต ์์ mixed precision์ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ํ์ต ๊ณผ์ ์์ ์์ ๊ฐ์ด with autocat()๋ฅผ ์ฌ์ฉํด์ผ ํ๋ค
# ์ธํผ๋ฐ์ค
import torch
from torch.cuda.amp import autocast
model = YourModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval() # ๋ชจ๋ธ์ evaluation ๋ชจ๋๋ก ์ค์
with torch.no_grad():
with autocast():
inputs = inputs.to(device)
outputs = model(inputs)
๋ชจ๋ธ ์ธํผ๋ฐ์ค ์ mixed precision์ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ์ ์ฝ๋์ฒ๋ผ ์ธํผ๋ฐ์ค ์ฝ๋์์ with autocast()๋ฅผ ์ฌ์ฉํด๋ ๋๊ณ ,
import torch
from torch import nn
from torch.cuda.amp import autocast
class YourModel(nn.Module):
def __init__(self):
super(YourModel, self).__init__()
# ๋ชจ๋ธ์ ๋ ์ด์ด๋ค์ ์ ์ํฉ๋๋ค.
def forward(self, inputs):
with autocast():
# forward ๋ฉ์๋ ๋ด์ ์ฐ์ฐ์ Mixed Precision์ผ๋ก ์ํํฉ๋๋ค.
x = self.layer1(inputs)
x = self.layer2(x)
x = self.layer3(x)
# ...
return x
์์ ๊ฐ์ด ๋ชจ๋ธ forward ๋ฉ์๋์์ with autocast()๋ฅผ ์ฌ์ฉํด๋ ๋๋ค.
'๐ป Programming > AI & ML' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[HuggingFace] Swin Transformer ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ ํ์ต ํํ ๋ฆฌ์ผ (0) | 2023.01.11 |
---|---|
[ONNX] pytorch ๋ชจ๋ธ์ ONNX๋ก ๋ณํํ๊ณ ์คํํ๊ธฐ (0) | 2022.12.21 |
[pytorch] pytorch ๋ชจ๋ธ ๋ก๋ ์ค Missing key(s) in state_dict ์๋ฌ (0) | 2022.12.15 |
[pytorch] COCO Data Format ์ ์ฉ Custom Dataset ์์ฑ (1) | 2022.06.04 |
[pytorch] model ์ ์ ๊ทผํ๊ธฐ, ํน์ layer ๋ณ๊ฒฝํ๊ธฐ (0) | 2022.01.05 |