์์ฆ์ ์ด์ง๊ฐํ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ GPU ์์ด ๋๋ฆฌ๊ธฐ ์ด๋ ต์ง๋ง, ๋ ์์ธ๋ก ๊ฐ๋ฒผ์ด ๋ชจ๋ธ๋ค์ CPU ๋ง์ผ๋ก ๋๋ฆด ์ ์๋ค. ๊ฐ๋ฅํ๋ค๋ฉด ํด๋ผ์ฐ๋ ๋น์ฉ๋ ์ค์ผ ์ ์์ผ๋ ์จ๋ผ์ธ ์์ธก์ด ํ์ํ ๊ฒฝ์ฐ๊ฐ ์๋๋ผ๋ฉด CPU ํ๊ฒฝ์์ ์ธํผ๋ฐ์คํ๋ ๊ฒ๋ ๊ณ ๋ คํด ๋ณผ ๋งํ๋ค.
๋ฌผ๋ก CPU๋ก ๋ฅ๋ฌ๋ ๋ชจ๋ธ ์ธํผ๋ฐ์ค๋ฅผ ํ๊ฒ ๋๋ฉด ์๋นํ ๋๋ฆฌ๋ค. ๋๋ฌธ์ ONNX ๋ชจ๋ธ ๋ณํ์ ํ๊ณ , ONNX runtime์ผ๋ก ์ธํผ๋ฐ์ค๋ฅผ ์ํํ๋ฉด ์กฐ๊ธ์ด๋ผ๋ ๋ชจ๋ธ ์ธํผ๋ฐ์ค ์๋๋ฅผ ํฅ์์ํฌ ์ ์๋ค. ๋ํ TensorRT์ ๋ฌ๋ฆฌ ONNX ๋ชจ๋ธ ๋ณํ์ ๊ฒฝ์ฐ ์ ๋ ฅ ํ ์ ํฌ๊ธฐ ๋ํ ๋์ ์ผ๋ก ๊ฐ์ ธ๊ฐ ์ ์๋ค๋ ์ฅ์ ์ด ์๋ค.
๋ฌผ๋ก ํ๋์จ์ด ํ๊ฒฝ์ ๋ฐ๋ผ, ๋ชจ๋ธ์ ๋ฐ๋ผ, ์ ๋ ฅ ํ ์์ ํฌ๊ธฐ์ ๋ฐ๋ผ ์๋ ํฅ์์ ์ ๋๊ฐ ๋ค๋ฅด๊ฑฐ๋, ์คํ๋ ค ์๋๊ฐ ๋๋ ค์ง ์๋ ์์ผ๋ ํ ์คํธ๋ฅผ ํด๋ด์ผ ํ๋ค.
Resnet ์ผ๋ก ๊ฐ๋จํ๊ฒ ํ ์คํธํด๋ดค์ ๋ ์ฝ 1.5~1.7๋ฐฐ ์ ๋์ ์๋ ํฅ์์ด ์์๊ณ , ํ์ฌ ์ฌ์ฉ์ค์ธ CNN ๊ธฐ๋ฐ์ ๊ฒ์ถ๊ธฐ๋ก ํ ์คํธ๋ฅผ ํด๋ดค์ ๋๋ ๋น์ทํ ์ ๋๋ก ์๋๊ฐ ํฅ์๋์๋ค.
์๋๊ฐ ๋ง์ด ๋น ๋ฅผ ํ์ ์๊ณ , ๋ชจ๋ธ์ด ์ด๋์ ๋ ๊ฐ๋ณ๋ค๋ฉด CPU ํ๊ฒฝ์์ ONNX ๋ฐํ์์ผ๋ก ๋ชจ๋ธ์ ๋ฐฐํฌํ๋ ๊ฒ๋ ์ถฉ๋ถํ ์๊ฐํด๋ณผ ์ ์๋ ์ต์ ์ธ ๊ฒ ๊ฐ๋ค.
ONNX Runtime ์์ ์ฝ๋
import torch
import torchvision
import numpy as np
import onnx
import onnxruntime as ort
from onnx import shape_inference
import time
# PyTorch ๋ชจ๋ธ ๋ก๋
torch_model = torchvision.models.resnet18(pretrained=False)
torch_model.eval()
# ์์ ์
๋ ฅ ๋ฐ์ดํฐ ์์ฑ
dummy_input = torch.randn(1, 3, 500, 500, requires_grad=True)
repetitions = 10
for _ in range(5):
_ = torch_model(dummy_input)
start = time.time()
with torch.no_grad():
for rep in range(repetitions):
torch_out = torch_model(dummy_input)
end = time.time()
print('torch ๋ชจ๋ธ ํ๊ท ์์ ์๊ฐ : ', (end-start)/repetitions)
# # ๋ชจ๋ธ ๋ณํ
torch.onnx.export(torch_model, # ์คํ๋ ๋ชจ๋ธ
dummy_input, # ๋ชจ๋ธ ์
๋ ฅ๊ฐ (ํํ ๋๋ ์ฌ๋ฌ ์
๋ ฅ๊ฐ๋ค๋ ๊ฐ๋ฅ)
"test_resnet18.onnx", # ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก (ํ์ผ ๋๋ ํ์ผ๊ณผ ์ ์ฌํ ๊ฐ์ฒด ๋ชจ๋ ๊ฐ๋ฅ)
export_params=True, # ๋ชจ๋ธ ํ์ผ ์์ ํ์ต๋ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ ์ง์ ์ฌ๋ถ
opset_version=10, # ๋ชจ๋ธ์ ๋ณํํ ๋ ์ฌ์ฉํ ONNX ๋ฒ์
do_constant_folding=True, # ์ต์ ํ์ ์์ํด๋ฉ์ ์ฌ์ฉํ ์ง์ ์ฌ๋ถ
input_names = ['input'], # ๋ชจ๋ธ์ ์
๋ ฅ๊ฐ์ ๊ฐ๋ฆฌํค๋ ์ด๋ฆ
output_names = ['output'], # ๋ชจ๋ธ์ ์ถ๋ ฅ๊ฐ์ ๊ฐ๋ฆฌํค๋ ์ด๋ฆ
dynamic_axes={'input' : {0: 'batch_size', 2: 'height', 3: 'width'}}, # ๊ฐ๋ณ์ ์ธ ๊ธธ์ด๋ฅผ ๊ฐ์ง ์ฐจ์
)
path = "./test_resnet18.onnx"
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
# # ONNX ๋ชจ๋ธ ๋ก๋
onnx_model = onnx.load("./test_resnet18.onnx")
onnx.checker.check_model(onnx_model)
# ONNX ๋ฐํ์ ์ธ์
์ด๊ธฐ (CPU ์ฌ์ฉ ์ค์ )
ort_session = ort.InferenceSession("./test_resnet18.onnx", providers=['CPUExecutionProvider'])
print(ort.get_device())
# ์ธํผ๋ฐ์ค ์คํ
ort_inputs = {ort_session.get_inputs()[0].name: np.array(dummy_input.detach())}
for _ in range(5):
_ = ort_session.run(None, ort_inputs)
start = time.time()
with torch.no_grad():
for rep in range(repetitions):
ort_outputs = ort_session.run(None, ort_inputs)
end = time.time()
print('ONNX ํ๊ท ์์ ์๊ฐ : ', (end-start)/repetitions)