[pytorch] How to Use Mixed Precision | torch.amp | torch.autocast | Speeding Up Model Training and Using Memory Efficiently

2022. 12. 20. 19:29 · 💻 Programming/AI & ML
Mixed Precision

Neural networks are typically trained in 32-bit floating point (FP32) precision, but recent hardware supports lower-precision (FP16) arithmetic, which can bring a speed advantage. However, reducing precision to FP16 narrows the range of representable numbers, which can degrade training quality.
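The narrower range of FP16 can be checked without a GPU. The sketch below uses only the Python standard library (the `struct` module's half-precision format `e`); the helper names `to_fp16` and `overflows_fp16` are made up for this illustration.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (FP16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def overflows_fp16(x: float) -> bool:
    """Return True if x is too large in magnitude to be packed as FP16."""
    try:
        struct.pack("<e", x)
        return False
    except OverflowError:
        return True

FP16_MAX = 65504.0  # largest finite FP16 value

print(to_fp16(FP16_MAX))        # still representable
print(overflows_fp16(70000.0))  # True: beyond the FP16 range
print(to_fp16(1e-8))            # 0.0: too small, flushed to zero
print(to_fp16(0.1))             # close to 0.1, but with a visible rounding error
```

Values above roughly 65504 overflow and values far below about 6e-8 vanish, which is why naive FP16 training can lose large activations and small gradients.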

 

Mixed precision is a technique that mixes floating-point precisions during deep learning training, which helps speed up training and reduce memory usage. It usually combines FP32 (32-bit floating point) and FP16 (16-bit floating point): most computation runs in FP16, while operations that are sensitive to overflow or underflow are carried out in FP32. Note that with torch.autocast the model's parameters themselves stay in FP32 and are cast to FP16 on the fly for the ops that benefit from it.

 

Because mixed precision generally saves memory and speeds up training while largely preserving accuracy, it is widely used for model training.
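As a rough illustration of the memory side, the sketch below compares the storage of an FP32 tensor and an FP16 tensor with the same number of elements. The element count `n` and the helper `tensor_bytes` are hypothetical, chosen only for this example.

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    """Number of bytes occupied by the tensor's elements."""
    return t.element_size() * t.nelement()

n = 1_000_000  # hypothetical number of activation values
fp32 = torch.zeros(n, dtype=torch.float32)
fp16 = torch.zeros(n, dtype=torch.float16)

print(tensor_bytes(fp32))  # 4 bytes per element
print(tensor_bytes(fp16))  # 2 bytes per element
```

In practice the overall saving is smaller than a flat 50%, since under autocast the parameters and optimizer state remain in FP32 and the reduction comes mainly from activations.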

 

AUTOMATIC MIXED PRECISION PACKAGE - TORCH.AMP

With torch.autocast from torch.amp, mixed precision can be applied to a PyTorch model with little effort.

 

Docs URL: https://pytorch.org/docs/stable/amp.html

torch.autocast

According to the torch.autocast documentation, ops inside an autocast region run in an op-specific dtype chosen by autocast to improve performance while maintaining accuracy. See the Autocast Op Reference for details.
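The per-op dtype choice can be observed directly. The following is a minimal sketch, assuming FP16 on a CUDA GPU and bfloat16 as a stand-in on CPU so that it also runs without a GPU; the layer sizes are arbitrary.

```python
import torch
from torch import nn

# Assumption: FP16 on CUDA, bfloat16 as the low-precision dtype on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

layer = nn.Linear(8, 4).to(device_type)
x = torch.randn(2, 8, device=device_type)

with torch.autocast(device_type=device_type, dtype=low_dtype):
    y = layer(x)  # linear is one of the ops autocast runs in lower precision

print(layer.weight.dtype)  # parameters are still stored in FP32
print(x.dtype)             # the input tensor itself is not modified
print(y.dtype)             # the output comes back in the low-precision dtype
```

Which ops are cast down and which are kept in FP32 depends on the device; the Autocast Op Reference lists them.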

 

autocast can also be used as a decorator on the forward method.

[Figure: autocast decorator]
[Figure: excerpt from the Autocast Op Reference]
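Decorator usage might look like the sketch below. `TinyNet` is a hypothetical model used only for illustration, and the CPU/bfloat16 fallback is an assumption so that the example runs without a GPU.

```python
import torch
from torch import nn

# Assumption: FP16 on CUDA, bfloat16 as the low-precision dtype on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

class TinyNet(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    # The decorator enables autocast for the whole forward pass.
    @torch.autocast(device_type=device_type, dtype=low_dtype)
    def forward(self, x):
        return self.fc(x)

model = TinyNet().to(device_type)
out = model(torch.randn(2, 8, device=device_type))
print(out.dtype)  # low-precision dtype, without any `with` block at the call site
```

With this form the caller does not need to remember to open an autocast context, at the cost of hard-coding the precision choice into the model.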

 

torch.autocast() usage example

# Training

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = YourModel()
criterion = nn.CrossEntropyLoss()  # placeholder loss; pick one that fits your task
optimizer = YourOptimizer(model.parameters(), lr=learning_rate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

scaler = GradScaler()

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        # Enable mixed precision for the forward pass and loss computation
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Scale the loss with GradScaler before backward, then step and update the scale
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

To use mixed precision during training, wrap the forward pass and loss computation in with autocast() as shown above. The backward pass and optimizer step stay outside the autocast block.
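The training loop above also goes through GradScaler, and the reason is gradient underflow in FP16. The sketch below illustrates the idea with the standard library only; the gradient value is hypothetical, `to_fp16` is a helper made up for this example, and 2**16 is GradScaler's default initial scale.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (FP16)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8        # hypothetical tiny gradient value
scale = 2.0 ** 16  # GradScaler's default initial scale

unscaled = to_fp16(grad)        # underflows to 0.0 in FP16
scaled = to_fp16(grad * scale)  # large enough to survive in FP16
recovered = scaled / scale      # unscaling is done in higher precision

print(unscaled)   # 0.0
print(recovered)  # close to 1e-8
```

This is only the arithmetic idea. The real GradScaler scales the loss so that every gradient is scaled through the chain rule, unscales before the optimizer step, skips steps whose gradients contain inf or NaN, and adjusts the scale over time.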

 

 

 

# Inference

import torch
from torch.cuda.amp import autocast

model = YourModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()  # set the model to evaluation mode

with torch.no_grad():
    with autocast():
        inputs = inputs.to(device)  # inputs: a batch from your own data pipeline
        outputs = model(inputs)

To use mixed precision at inference time, you can wrap the inference code in with autocast() as in the code above. Alternatively:

 

import torch
from torch import nn
from torch.cuda.amp import autocast

class YourModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Define the model's layers here.

    def forward(self, inputs):
        with autocast():
            # Operations inside forward run in mixed precision.
            x = self.layer1(inputs)
            x = self.layer2(x)
            x = self.layer3(x)
            # ...
            return x

As shown above, with autocast() can also be placed inside the model's forward method.
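To get a feel for how much the outputs change at inference time, one can compare a plain FP32 forward pass with an autocast one. The following is a minimal sketch with a hypothetical toy model and random inputs, assuming FP16 on CUDA and bfloat16 as a stand-in on CPU.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Assumption: FP16 on CUDA, bfloat16 as the low-precision dtype on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
low_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

# Hypothetical toy model and random inputs, for illustration only.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
model = model.to(device_type).eval()
inputs = torch.randn(8, 16, device=device_type)

with torch.no_grad():
    ref = model(inputs)  # plain FP32 forward pass
    with torch.autocast(device_type=device_type, dtype=low_dtype):
        out = model(inputs)  # same weights, lower-precision matmuls

max_diff = (ref - out.float()).abs().max().item()
print(out.dtype, max_diff)
```

The difference is small for a toy model like this one, but the acceptable error depends on the task, so it is worth checking on real data before switching an inference pipeline over.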
