
[pytorch] How to Use Mixed Precision | torch.amp | torch.autocast | Speeding Up Model Training and Using Memory More Efficiently

by ๋ญ…์ฆค 2022. 12. 20.
๋ฐ˜์‘ํ˜•
Mixed Precision

์ผ๋ฐ˜์ ์ธ neural network์—์„œ๋Š” 32-bit floating point(FP32) precision์„ ์ด์šฉํ•˜์—ฌ ํ•™์Šต์„ ์‹œํ‚ค๋Š”๋ฐ, ์ตœ์‹  ํ•˜๋“œ์›จ์–ด์—์„œ๋Š” lower precision(FP16) ๊ณ„์‚ฐ์ด ์ง€์›๋˜๋ฉด์„œ ์†๋„์—์„œ ์ด์ ์„ ์–ป์„ ์ˆ˜ ์žˆ๋‹ค. ํ•˜์ง€๋งŒ FP16์œผ๋กœ precision์„ ์ค„์ด๋ฉด ์ˆ˜๋ฅผ ํ‘œํ˜„ํ•˜๋Š” ๋ฒ”์œ„๊ฐ€ ์ค„์–ด๋“ค์–ด ํ•™์Šต ์„ฑ๋Šฅ์ด ์ €ํ•˜๋  ์ˆ˜ ์žˆ๋‹ค.

 

Mixed Precision์€ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ํ•™์Šต ๊ณผ์ •์—์„œ ๋ถ€๋™์†Œ์ˆ˜์  ์—ฐ์‚ฐ์˜ ์ •๋ฐ€๋„๋ฅผ ํ˜ผํ•ฉํ•˜์—ฌ ์‚ฌ์šฉํ•˜๋Š” ๊ธฐ์ˆ ๋กœ, ํ•™์Šต ์†๋„๋ฅผ ๋†’์ด๊ณ  ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๋Š” ๋ฐ ๋„์›€์„ ์ค€๋‹ค. Mixed Precision์€ ๋Œ€๊ฐœ FP32(32๋น„ํŠธ ๋ถ€๋™์†Œ์ˆ˜์ )์™€ FP16(16๋น„ํŠธ ๋ถ€๋™์†Œ์ˆ˜์ )์„ ์กฐํ•ฉํ•˜์—ฌ ์‚ฌ์šฉํ•˜๋ฉฐ, ๊ฐ€์ค‘์น˜์™€ ๊ทธ๋ž˜๋””์–ธํŠธ๋Š” FP16๋กœ ์ €์žฅํ•˜๊ณ  ์—ฐ์‚ฐ์€ FP16๋กœ ์ˆ˜ํ–‰ํ•˜๋ฉด์„œ ์ผ๋ถ€ ์—ฐ์‚ฐ์—์„œ๋Š” FP32๋กœ ์ „ํ™˜ํ•˜์—ฌ ์˜ค๋ฒ„ํ”Œ๋กœ์šฐ ๋ฐ ์–ธ๋”ํ”Œ๋กœ์šฐ๋ฅผ ๋ฐฉ์ง€ํ•œ๋‹ค.

 

Mixed Precision์„ ์‚ฌ์šฉํ•˜๋ฉด ์ผ๋ฐ˜์ ์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์ ˆ์•ฝํ•˜๊ณ  ํ•™์Šต ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ค๋ฉด์„œ ์ •ํ™•๋„๋Š” ์–ด๋Š ์ •๋„ ์œ ์ง€ํ•  ์ˆ˜ ์žˆ๊ธฐ์— ๋ชจ๋ธ ํ•™์Šต ์‹œ ์ž์ฃผ ์‚ฌ์šฉ๋œ๋‹ค.

 

AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP

torch.amp์˜ torch.autocast๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด pytorch ๋ชจ๋ธ์— mixed precision์„ ๊ฐ„๋‹จํ•˜๊ฒŒ ์ ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

 

Docs URL: https://pytorch.org/docs/stable/amp.html

torch.autocast

์œ„ ์„ค๋ช…์—์„œ, ops๋Š” ์ •ํ™•๋„๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ์„ฑ๋Šฅ์„ ํ–ฅ์ƒ์‹œํ‚ค๊ธฐ ์œ„ํ•ด autocast์—์„œ ์„ ํƒํ•œ ํŠน์ • op dtype์—์„œ๋งŒ ์‹คํ–‰๋œ๋‹ค๊ณ  ๋‚˜์™€ ์žˆ๋‹ค. Autocast Op Reference๋ฅผ ๋ณด๋ฉด ์ž์„ธํ•œ ์„ค๋ช…์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

autocast can also be used as a decorator on a model's forward method, as in the sketch below.

autocast ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ
Autocast Op Reference ์ผ๋ถ€
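A minimal sketch of the decorator form, using the same hypothetical YourModel naming as the examples below (the single Linear layer is a placeholder):

import torch
from torch import nn
from torch.cuda.amp import autocast

class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(64, 64)  # placeholder layer

    @autocast()  # every call to forward now runs under autocast
    def forward(self, x):
        return self.layer(x)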

 

torch.autocast usage example

# Training

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = YourModel()                # placeholder for your model class
optimizer = YourOptimizer(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()  # placeholder loss function

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

scaler = GradScaler()

# num_epochs and dataloader are assumed to be defined elsewhere
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        # Run the forward pass and loss computation under autocast
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Scale the loss with GradScaler so FP16 gradients do not underflow,
        # then step and update the scale factor
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

๋ชจ๋ธ ํ•™์Šต ์‹œ์— mixed precision์„ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ํ•™์Šต ๊ณผ์ •์—์„œ ์œ„์™€ ๊ฐ™์ด with autocat()๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•œ๋‹ค

# ์ธํผ๋Ÿฐ์Šค

import torch
from torch.cuda.amp import autocast

model = YourModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # ๋ชจ๋ธ์„ evaluation ๋ชจ๋“œ๋กœ ์„ค์ •

with torch.no_grad():
    with autocast():
        inputs = inputs.to(device)
        outputs = model(inputs)

๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค ์‹œ mixed precision์„ ์‚ฌ์šฉํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ์œ„ ์ฝ”๋“œ์ฒ˜๋Ÿผ ์ธํผ๋Ÿฐ์Šค ์ฝ”๋“œ์—์„œ with autocast()๋ฅผ ์‚ฌ์šฉํ•ด๋„ ๋˜๊ณ ,

 

import torch
from torch import nn
from torch.cuda.amp import autocast

class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Define your model's layers here; these are placeholders
        self.layer1 = nn.Linear(64, 64)
        self.layer2 = nn.ReLU()
        self.layer3 = nn.Linear(64, 10)

    def forward(self, inputs):
        with autocast():
            # Operations inside forward run under mixed precision
            x = self.layer1(inputs)
            x = self.layer2(x)
            x = self.layer3(x)
            return x

์œ„์™€ ๊ฐ™์ด ๋ชจ๋ธ forward ๋ฉ”์„œ๋“œ์—์„œ with autocast()๋ฅผ ์‚ฌ์šฉํ•ด๋„ ๋œ๋‹ค.

๋ฐ˜์‘ํ˜•