Pytorch 2.0
compiled_model = torch.compile(model)
๊ธด ์ค๋ช ํ ๊ฒ ์์ด Pytorch 2.0 ์ดํ compile ์ด๋ผ๋ ๊ฒ์ด ์ถ๊ฐ๋์๋๋ฐ, ์ ์์์ฒ๋ผ torch.comile(model)์ด๋ผ๋ ์งง์ ์ฝ๋ ํ ์ค๋ง ์ถ๊ฐํ๋ฉด ๋ชจ๋ธ ์ธํผ๋ฐ์ค ์๋๋ฅผ ํฅ์์ํฌ ์ ์๋ค๊ณ ํ๋ค.
๊ณต์ ๋ฌธ์์ ๋ฐ๋ฅด๋ฉด A100 GPU์์ ๋ชจ๋ธ ํ์ต ์๋๋ 43% ํฅ์๋๊ณ , ๋ชจ๋ธ ์ธํผ๋ฐ์ค ์๋๋ Float32 precision์์ 21%, AMP precision์์ 51% ์ ๋ ํฅ์๋๋ค๊ณ ํ๋ค.
Pytorch 2.0 Compile ๋ชจ๋ธ ์ธํผ๋ฐ์ค ํ ์คํธ
torchvision์์ ์ ๊ณตํ๋ ๊ธฐ๋ณธ์ ์ธ ๋ชจ๋ธ์ธ resnet50์ผ๋ก pytorch compile์ ์๋ ๊ฐ์ ํ ์คํธ๋ฅผ ์งํํด๋ดค๋ค.
# ํ ์คํธ ์ฝ๋
import torch
import torchvision.models
import numpy as np
model = torchvision.models.resnet50()
model.eval().cuda()
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
repetitions = 10
timings=np.zeros((repetitions,1))
B = 1
H = 1000
W = 1000
input_data = torch.randn((B, 3, H, W)).float().cuda()
for _ in range(5):
_ = model(input_data)
torch.cuda.synchronize()
with torch.no_grad():
for rep in range(repetitions):
starter.record()
torch_out = model(input_data)
ender.record()
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)
timings[rep] = curr_time
# new_H = torch.randint(H-100,H+100,(1,))
# new_W = torch.randint(W-100,W+100,(1,))
# input_data = torch.randn((1, 3, new_H, new_W)).float().cuda()
print('torch ๋ชจ๋ธ ํ๊ท ์์ ์๊ฐ : ', np.mean(np.array(timings)))
# model_compiled = torch.compile(model, dynamic=True)
model_compiled = torch.compile(model)
del model
for _ in range(5):
_ = model_compiled(input_data)
torch.cuda.synchronize()
timings=np.zeros((repetitions,1))
with torch.no_grad():
for rep in range(repetitions):
starter.record()
complied_out = model_compiled(input_data)
ender.record()
torch.cuda.synchronize()
curr_time = starter.elapsed_time(ender)
timings[rep] = curr_time
# new_H = torch.randint(H-100,H+100,(1,))
# new_W = torch.randint(W-100,W+100,(1,))
# input_data = torch.randn((1, 3, new_H, new_W)).float().cuda()
print('compiled ๋ชจ๋ธ ํ๊ท ์์ ์๊ฐ : ', np.mean(np.array(timings)))
# ๊ฒฐ๊ณผ ๋น๊ต
error = torch.abs(torch_out - complied_out).mean()
print(f"Mean Absolute Error: {error.item()}")
- ํ ์คํธ ๋ชจ๋ธ : resnet50
- ํ
์คํธ ๋ฐฉ๋ฒ
- pytorch ๋ชจ๋ธ๊ณผ compileํ ๋ชจ๋ธ์ ์์ฑํ ํ ๋ช ๊ฐ์ง ์ ๋ ฅ ํฌ๊ธฐ์ ๋ํด ์๋ ํ ์คํธ
- ์ ๋ ฅ ํ ์๋ฅผ ๋ชจ๋ธ์ 10ํ ๋ฐ๋ณตํด์ ์ ๋ ฅํ ํ ํ๊ท ์์ ์๊ฐ์ ์ธก์
- ์
๋ ฅ ํฌ๊ธฐ
- ์ ์ ์ธ ์ ๋ ฅ ํฌ๊ธฐ๋ก ํ ์คํธ
- torch.compile()์ ํ๋ผ๋ฏธํฐ์์ dynamic=True๋ก ์ค์ ํ๋ฉด dynamic input shape์ ๋์์ด ๊ฐ๋ฅํ๋ค ํ์ฌ ํ ์คํธ
# ํ ์คํธ ๊ฒฐ๊ณผ
Input shape | ์ธํผ๋ฐ์ค ํ๊ท ์์ ์๊ฐ (ms) | ์๋ ํฅ์ (%) | |
Pytorch model | Compiled model | ||
[1,3,500,500] | 9.78 | 14.56 | -32.83 |
[10,3,500,500] | 44.78 | 38.15 | 17.38 |
[1,3,1000,1000] | 20.46 | 17.88 | 14.43 |
[8,3,1000,1000] | 140.84 | 113.8 | 23.76 |
[1,3,900~1100,900~1100] (Dynamic shape) |
162.26 | 8513 | -98.09 |
- ์์ธ์ง [1,3,500,500] ํฌ๊ธฐ์ ์์ ์ ๋ ฅ ํฌ๊ธฐ์์๋ compile ๋ชจ๋ธ์ด ์คํ๋ ค ์๋ ๊ฐ์
- ์ ๋ ฅ ํฌ๊ธฐ๋ ์๋๋ผ๋ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ํค์ฐ๋ฉด ์๋ ํฅ์
- ์
๋ ฅ ํฌ๊ธฐ๋ฅผ ํค์ฐ๋ฉด ๋ฐฐ์น=1 ์์๋ ์๋ ํฅ์๋๊ณ , ๋ฐฐ์น๊ฐ ์ปค์ง์๋ก ์๋ ํฅ์์ด ๋๋๋ฌ์ง
- ์ ๋ฐ์ ์ผ๋ก 15~20%์ ์๋ ํฅ์ (๊ณต์ ๋ฌธ์์ ๋น์ทํ ์์น)
- ๋์ ์ธ ์
๋ ฅ์์์ ํ
์คํธ๋... ์๋๊ฐ ์คํ๋ ค ๋งค์ฐ ๋๋ ค์ง
- ๊ณต์ ๋ฌธ์์ ๋ฐ๋ฅด๋ฉด ํ์ฌ๋ Danamic shape ์ ๋ ฅ์ ๋ํ ์ง์์ด ์ ํ๋์ด ์๋ค๊ณ ํจ
→ Pytorch Compile ์ฌ์ฉ ์ ์ฌ์ฉํ๋ ๋ชจ๋ธ๊ณผ ์ ๋ ฅ ํฌ๊ธฐ์ ๋ฐ๋ผ ์๋ ๊ฐ์ ์ ๋๊ฐ ๋ค๋ฅผ ์ ์์ผ๋ ํ ์คํธ๋ฅผ ํ ํ์ ์ฌ์ฉํ๋ ๊ฒ์ด ์ข์ ๊ฒ ๊ฐ๋ค.
→ ํ์ต ์์๋ ๋ฐฐ์น๊ฐ ์ ๋ ฅ ํฌ๊ธฐ๊ฐ ๊ณ ์ ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ผ๋ ์๊ด์์ง๋ง, ์ธํผ๋ฐ์ค ์ ๋์ ์ธ ์ ๋ ฅ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ์๋ ์ฌ์ฉ์ด ์ ํ๋ ์ ์๋ค.