๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
๐Ÿ’ป Programming/AI & ML

[ONNX] ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ONNX Runtime์œผ๋กœ CPU ํ™˜๊ฒฝ์—์„œ ๊ฐ€์†ํ™”ํ•˜๊ธฐ

by ๋ญ…์ฆค 2023. 11. 16.
๋ฐ˜์‘ํ˜•

์š”์ฆ˜์€ ์–ด์ง€๊ฐ„ํ•œ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ GPU ์—†์ด ๋Œ๋ฆฌ๊ธฐ ์–ด๋ ต์ง€๋งŒ, ๋˜ ์˜์™ธ๋กœ ๊ฐ€๋ฒผ์šด ๋ชจ๋ธ๋“ค์€ CPU ๋งŒ์œผ๋กœ ๋Œ๋ฆด ์ˆ˜ ์žˆ๋‹ค. ๊ฐ€๋Šฅํ•˜๋‹ค๋ฉด ํด๋ผ์šฐ๋“œ ๋น„์šฉ๋„ ์ค„์ผ ์ˆ˜ ์žˆ์œผ๋‹ˆ ์˜จ๋ผ์ธ ์˜ˆ์ธก์ด ํ•„์š”ํ•œ ๊ฒฝ์šฐ๊ฐ€ ์•„๋‹ˆ๋ผ๋ฉด CPU ํ™˜๊ฒฝ์—์„œ ์ธํผ๋Ÿฐ์Šคํ•˜๋Š” ๊ฒƒ๋„ ๊ณ ๋ คํ•ด ๋ณผ ๋งŒํ•˜๋‹ค.

 

๋ฌผ๋ก  CPU๋กœ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค๋ฅผ ํ•˜๊ฒŒ ๋˜๋ฉด ์ƒ๋‹นํžˆ ๋Š๋ฆฌ๋‹ค. ๋•Œ๋ฌธ์— ONNX ๋ชจ๋ธ ๋ณ€ํ™˜์„ ํ•˜๊ณ , ONNX runtime์œผ๋กœ ์ธํผ๋Ÿฐ์Šค๋ฅผ ์ˆ˜ํ–‰ํ•˜๋ฉด ์กฐ๊ธˆ์ด๋ผ๋„ ๋ชจ๋ธ ์ธํผ๋Ÿฐ์Šค ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ๋‹ค. ๋˜ํ•œ TensorRT์™€ ๋‹ฌ๋ฆฌ ONNX ๋ชจ๋ธ ๋ณ€ํ™˜์˜ ๊ฒฝ์šฐ ์ž…๋ ฅ ํ…์„œ ํฌ๊ธฐ ๋˜ํ•œ ๋™์ ์œผ๋กœ ๊ฐ€์ ธ๊ฐˆ ์ˆ˜ ์žˆ๋‹ค๋Š” ์žฅ์ ์ด ์žˆ๋‹ค.

 

๋ฌผ๋ก  ํ•˜๋“œ์›จ์–ด ํ™˜๊ฒฝ์— ๋”ฐ๋ผ, ๋ชจ๋ธ์— ๋”ฐ๋ผ, ์ž…๋ ฅ ํ…์„œ์˜ ํฌ๊ธฐ์— ๋”ฐ๋ผ ์†๋„ ํ–ฅ์ƒ์˜ ์ •๋„๊ฐ€ ๋‹ค๋ฅด๊ฑฐ๋‚˜, ์˜คํžˆ๋ ค ์†๋„๊ฐ€ ๋Š๋ ค์งˆ ์ˆ˜๋„ ์žˆ์œผ๋‹ˆ ํ…Œ์ŠคํŠธ๋ฅผ ํ•ด๋ด์•ผ ํ•œ๋‹ค.

 

Resnet ์œผ๋กœ ๊ฐ„๋‹จํ•˜๊ฒŒ ํ…Œ์ŠคํŠธํ•ด๋ดค์„ ๋•Œ ์•ฝ 1.5~1.7๋ฐฐ ์ •๋„์˜ ์†๋„ ํ–ฅ์ƒ์ด ์žˆ์—ˆ๊ณ , ํ˜„์žฌ ์‚ฌ์šฉ์ค‘์ธ CNN ๊ธฐ๋ฐ˜์˜ ๊ฒ€์ถœ๊ธฐ๋กœ ํ…Œ์ŠคํŠธ๋ฅผ ํ•ด๋ดค์„ ๋•Œ๋„ ๋น„์Šทํ•œ ์ •๋„๋กœ ์†๋„๊ฐ€ ํ–ฅ์ƒ๋˜์—ˆ๋‹ค.

 

์†๋„๊ฐ€ ๋งŽ์ด ๋น ๋ฅผ ํ•„์š” ์—†๊ณ , ๋ชจ๋ธ์ด ์–ด๋Š์ •๋„ ๊ฐ€๋ณ๋‹ค๋ฉด CPU ํ™˜๊ฒฝ์—์„œ ONNX ๋Ÿฐํƒ€์ž„์œผ๋กœ ๋ชจ๋ธ์„ ๋ฐฐํฌํ•˜๋Š” ๊ฒƒ๋„ ์ถฉ๋ถ„ํžˆ ์ƒ๊ฐํ•ด๋ณผ ์ˆ˜ ์žˆ๋Š” ์˜ต์…˜์ธ ๊ฒƒ ๊ฐ™๋‹ค.

 

ONNX Runtime ์˜ˆ์ œ ์ฝ”๋“œ

import torch
import torchvision
import numpy as np
import onnx
import onnxruntime as ort
from onnx import shape_inference
import time


# PyTorch ๋ชจ๋ธ ๋กœ๋“œ
torch_model = torchvision.models.resnet18(pretrained=False)
torch_model.eval()

# ์˜ˆ์ œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
dummy_input = torch.randn(1, 3, 500, 500, requires_grad=True)

repetitions = 10

# Warm-up runs (excluded from timing)
with torch.no_grad():
    for _ in range(5):
        _ = torch_model(dummy_input)

start = time.time()
with torch.no_grad():
    for rep in range(repetitions):
        torch_out = torch_model(dummy_input)
end = time.time()

print('PyTorch average latency (s): ', (end - start) / repetitions)
    

# # ๋ชจ๋ธ ๋ณ€ํ™˜
torch.onnx.export(torch_model,               # ์‹คํ–‰๋  ๋ชจ๋ธ
                    dummy_input,                         # ๋ชจ๋ธ ์ž…๋ ฅ๊ฐ’ (ํŠœํ”Œ ๋˜๋Š” ์—ฌ๋Ÿฌ ์ž…๋ ฅ๊ฐ’๋“ค๋„ ๊ฐ€๋Šฅ)
                    "test_resnet18.onnx",   # ๋ชจ๋ธ ์ €์žฅ ๊ฒฝ๋กœ (ํŒŒ์ผ ๋˜๋Š” ํŒŒ์ผ๊ณผ ์œ ์‚ฌํ•œ ๊ฐ์ฒด ๋ชจ๋‘ ๊ฐ€๋Šฅ)
                    export_params=True,        # ๋ชจ๋ธ ํŒŒ์ผ ์•ˆ์— ํ•™์Šต๋œ ๋ชจ๋ธ ๊ฐ€์ค‘์น˜๋ฅผ ์ €์žฅํ• ์ง€์˜ ์—ฌ๋ถ€
                    opset_version=10,          # ๋ชจ๋ธ์„ ๋ณ€ํ™˜ํ•  ๋•Œ ์‚ฌ์šฉํ•  ONNX ๋ฒ„์ „
                    do_constant_folding=True,  # ์ตœ์ ํ™”์‹œ ์ƒ์ˆ˜ํด๋”ฉ์„ ์‚ฌ์šฉํ• ์ง€์˜ ์—ฌ๋ถ€
                    input_names = ['input'],   # ๋ชจ๋ธ์˜ ์ž…๋ ฅ๊ฐ’์„ ๊ฐ€๋ฆฌํ‚ค๋Š” ์ด๋ฆ„
                    output_names = ['output'], # ๋ชจ๋ธ์˜ ์ถœ๋ ฅ๊ฐ’์„ ๊ฐ€๋ฆฌํ‚ค๋Š” ์ด๋ฆ„
                    dynamic_axes={'input' : {0: 'batch_size', 2: 'height', 3: 'width'}},    # ๊ฐ€๋ณ€์ ์ธ ๊ธธ์ด๋ฅผ ๊ฐ€์ง„ ์ฐจ์›
                    )


# Run shape inference and save the result back to the same file
path = "./test_resnet18.onnx"
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)

# # ONNX ๋ชจ๋ธ ๋กœ๋“œ
onnx_model = onnx.load("./test_resnet18.onnx")
onnx.checker.check_model(onnx_model)

# ONNX ๋Ÿฐํƒ€์ž„ ์„ธ์…˜ ์—ด๊ธฐ (CPU ์‚ฌ์šฉ ์„ค์ •)
ort_session = ort.InferenceSession("./test_resnet18.onnx", providers=['CPUExecutionProvider'])
print(ort.get_device())

# ์ธํผ๋Ÿฐ์Šค ์‹คํ–‰
ort_inputs = {ort_session.get_inputs()[0].name: np.array(dummy_input.detach())}

for _ in range(5):
    _ = ort_session.run(None, ort_inputs)

# torch.no_grad() is not needed here; ONNX Runtime does not track gradients
start = time.time()
for rep in range(repetitions):
    ort_outputs = ort_session.run(None, ort_inputs)
end = time.time()

print('ONNX Runtime average latency (s): ', (end - start) / repetitions)
๋ฐ˜์‘ํ˜•