
[Model Inference] PyTorch 2.0 Compile: Review, Pros and Cons | Testing Inference Speedups with PyTorch Compile

by 뭅즤 2023. 10. 7.
๋ฐ˜์‘ํ˜•

PyTorch Compile tutorial

PyTorch 2.0 Overview

[Figure: PyTorch 2.0, speedups for torch.compile against eager mode on an NVIDIA A100 GPU]

 

```python
compiled_model = torch.compile(model)
```

Without going into a long explanation: PyTorch 2.0 added a feature called compile, and, as in the example above, adding just one short line of code, torch.compile(model), is said to improve model inference speed.
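To try the one-liner without a GPU, here is a minimal sketch, assuming PyTorch 2.0 or later. The small `nn.Sequential` model is a hypothetical stand-in for resnet50, and `backend="eager"` is my choice only to keep the example quick on CPU: TorchDynamo still captures the graph, but it is executed with ordinary PyTorch ops, so this checks that compilation works and that the outputs match, not that anything gets faster. Dropping the `backend` argument uses the default Inductor backend, which is what the benchmark in this post uses.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for resnet50
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# backend="eager": TorchDynamo captures the graph but runs it with regular
# PyTorch ops (no Inductor code generation), so this checks correctness, not speed
compiled_model = torch.compile(model, backend="eager")

x = torch.randn(8, 16)
with torch.no_grad():
    eager_out = model(x)
    compiled_out = compiled_model(x)  # the first call triggers graph capture

print(torch.allclose(eager_out, compiled_out))
```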

 

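According to the official documentation, on an NVIDIA A100 GPU, compiled models ran 43% faster on average in training. Broken down by precision, they ran 21% faster at Float32 and 51% faster with AMP (the 43% headline weights AMP at 0.75 and Float32 at 0.25). Note that these official figures are training benchmarks, not inference measurements.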

 

 

PyTorch 2.0 Compile Model Inference Test

I tested how much PyTorch compile speeds up resnet50, a standard model provided by torchvision.

 

# Test code

```python
import torch
import torchvision.models
import numpy as np

model = torchvision.models.resnet50()
model.eval().cuda()

# CUDA events time the GPU work between record() calls (in milliseconds)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 10
timings = np.zeros((repetitions, 1))

B = 1
H = 1000
W = 1000
input_data = torch.randn((B, 3, H, W)).float().cuda()

with torch.no_grad():
    # Warm-up so one-time CUDA/cuDNN initialization is not timed
    for _ in range(5):
        _ = model(input_data)
    torch.cuda.synchronize()

    for rep in range(repetitions):
        starter.record()
        torch_out = model(input_data)
        ender.record()
        torch.cuda.synchronize()  # wait for the GPU before reading the timer
        timings[rep] = starter.elapsed_time(ender)
        # Dynamic-shape test: draw a new H, W in [H-100, H+100) every iteration
        # new_H = int(torch.randint(H - 100, H + 100, (1,)))
        # new_W = int(torch.randint(W - 100, W + 100, (1,)))
        # input_data = torch.randn((1, 3, new_H, new_W)).float().cuda()

print('torch model mean latency (ms): ', timings.mean())

# model_compiled = torch.compile(model, dynamic=True)  # dynamic-shape test
model_compiled = torch.compile(model)

timings = np.zeros((repetitions, 1))
with torch.no_grad():
    # Warm-up: the first calls trigger compilation. Run them under no_grad as well,
    # since torch.compile may recompile when the grad mode changes, and that
    # recompilation would otherwise land inside the timed loop.
    for _ in range(5):
        _ = model_compiled(input_data)
    torch.cuda.synchronize()

    for rep in range(repetitions):
        starter.record()
        compiled_out = model_compiled(input_data)
        ender.record()
        torch.cuda.synchronize()
        timings[rep] = starter.elapsed_time(ender)
        # new_H = int(torch.randint(H - 100, H + 100, (1,)))
        # new_W = int(torch.randint(W - 100, W + 100, (1,)))
        # input_data = torch.randn((1, 3, new_H, new_W)).float().cuda()

print('compiled model mean latency (ms): ', timings.mean())

# Compare the outputs of the two models
error = torch.abs(torch_out - compiled_out).mean()
print(f"Mean Absolute Error: {error.item()}")
```
  • Test model: resnet50
  • Test method
    • Create the PyTorch model and the compiled model, then measure the speed for several input sizes
    • Feed the input tensor to each model 10 times and take the mean latency
  • Input sizes
    • Static input sizes
    • Dynamic input sizes, since setting dynamic=True in torch.compile() is said to support dynamic input shapes
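The measurement procedure above (a few untimed warm-up calls, then 10 timed calls, then the mean) can be wrapped in a small reusable helper. This is a sketch under my own assumptions: `benchmark_ms` is a hypothetical name, and it reads wall-clock time with `time.perf_counter` plus an optional `sync` callback instead of using the CUDA events from the test code. With `sync=torch.cuda.synchronize`, the clock is read only after queued GPU work has finished; unlike CUDA events, the result then also includes Python-side overhead, which can matter for small inputs.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark_ms(
    fn: Callable[[], object],
    warmup: int = 5,
    repetitions: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the mean wall-clock latency of fn() in milliseconds.

    Hypothetical helper: warm-up calls are not timed (for a compiled model they
    absorb the compilation cost). `sync`, if given, is called after the warm-up
    and after every timed call so that asynchronous GPU work is included.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    timings = []
    for _ in range(repetitions):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append((time.perf_counter() - start) * 1000.0)
    return statistics.mean(timings)


# Usage with a trivial CPU-only workload (nothing asynchronous, so no sync)
mean_ms = benchmark_ms(lambda: sum(range(10_000)))
print(f"mean latency: {mean_ms:.3f} ms")
```

With the post's setup it would be called as `benchmark_ms(lambda: model_compiled(input_data), sync=torch.cuda.synchronize)` inside a `torch.no_grad()` block, once for the eager model and once for the compiled one.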

 

 

# Test results

| Input shape | PyTorch model mean latency (ms) | Compiled model mean latency (ms) | Speedup (%) |
|---|---|---|---|
| [1,3,500,500] | 9.78 | 14.56 | -32.83 |
| [10,3,500,500] | 44.78 | 38.15 | 17.38 |
| [1,3,1000,1000] | 20.46 | 17.88 | 14.43 |
| [8,3,1000,1000] | 140.84 | 113.8 | 23.76 |
| [1,3,900~1100,900~1100] (dynamic shape) | 162.26 | 8513 | -98.09 |
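The last column of the table appears to be the relative speedup, (eager latency / compiled latency - 1) × 100; the sketch below recomputes it from the two latency columns and reproduces every row. It also prints the latency reduction, (eager - compiled) / eager × 100, as an alternative reading of the same numbers: for [10,3,500,500] the compiled model is 17.38% faster by the first measure but cuts latency by only about 14.8%. The function names are illustrative, not from the original post.

```python
# Mean latencies (ms) copied from the table: (input shape, eager, compiled)
RESULTS = [
    ("[1,3,500,500]", 9.78, 14.56),
    ("[10,3,500,500]", 44.78, 38.15),
    ("[1,3,1000,1000]", 20.46, 17.88),
    ("[8,3,1000,1000]", 140.84, 113.8),
    ("[1,3,900~1100,900~1100] (dynamic)", 162.26, 8513.0),
]


def speedup_pct(eager_ms: float, compiled_ms: float) -> float:
    """Relative speedup as in the table: (eager / compiled - 1), in percent."""
    return (eager_ms / compiled_ms - 1.0) * 100.0


def latency_reduction_pct(eager_ms: float, compiled_ms: float) -> float:
    """Share of the eager latency removed by compiling, in percent."""
    return (eager_ms - compiled_ms) / eager_ms * 100.0


for shape, eager_ms, compiled_ms in RESULTS:
    print(
        f"{shape:36s} speedup {speedup_pct(eager_ms, compiled_ms):8.2f}%  "
        f"latency reduction {latency_reduction_pct(eager_ms, compiled_ms):9.2f}%"
    )
```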
  • For some reason, with the small [1,3,500,500] input, the compiled model was actually slower (-32.83%)
  • Even with that small input size, increasing the batch size (to 10) produced a speedup
  • With larger inputs there is a speedup even at batch size 1, and the speedup becomes more pronounced as the batch grows
    • Overall, roughly 14 to 24% faster in the cases that improved, comparable to the 21% Float32 figure in the official documentation
  • In the dynamic-input test, however, the compiled model became dramatically slower (8513 ms vs. 162.26 ms)
    • According to the official documentation, support for dynamic-shape inputs is currently limited

→ With PyTorch compile, the size of the speedup can vary with the model and the input size, so it is best to benchmark your own case before adopting it.

→ During training, the batch size and input size are often fixed, so this is not a concern, but for inference with dynamic input shapes its usefulness may be limited.

 

๋ฐ˜์‘ํ˜•