ONNX (Open Neural Network eXchange)
ONNX๋ ๊ธฐ๊ณ ํ์ต ๋ชจ๋ธ์ ํํํ๊ธฐ ์ํด ๋ง๋ค์ด์ง ์คํ ํฌ๋งท์ผ๋ก ONNX ๋ฐํ์์ ์ฌ๋ฌ ๋ค์ํ ํ๋ซํผ๊ณผ ํ๋์จ์ด(์๋์ฐ, ๋ฆฌ๋ ์ค, ๋งฅ์ ๋น๋กฏํ ํ๋ซํผ ๋ฟ๋ง ์๋๋ผ CPU, GPU ๋ฑ์ ํ๋์จ์ด)์์ ํจ์จ์ ์ธ ์ถ๋ก ์ ๊ฐ๋ฅํ๊ฒ ํ๋ค. ๋๋ฌธ์ ๋ค์ํ ํ๋ ์์ํฌ์์ ์ฐ๊ณ๊ฐ ํ์ํ ๋ ONNX๋ฅผ ์ฌ์ฉํ๋ค. (pytorch ↔๏ธ tensorflow ↔๏ธ caffe2 ↔๏ธ MXNet ↔๏ธ ...)
*์ฐธ๊ณ
- https://pytorch.org/docs/stable/onnx.html
- https://tutorials.pytorch.kr/advanced/super_resolution_with_onnxruntime.html
ONNX ์์
1. Pytorch ๋ชจ๋ธ ๊ตฌํ
# ํ์ํ import๋ฌธ
import io
import numpy as np
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx
import torch.nn as nn
import torch.nn.init as init
# ์ ์๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ดํด์๋ ๋ชจ๋ธ ์์ฑ
torch_model = SuperResolutionNet(upscale_factor=3)
2. ํ์ต๋ ์จ์ดํธ ๋ถ๋ฌ์ค๊ธฐ
# ๋ฏธ๋ฆฌ ํ์ต๋ ๊ฐ์ค์น๋ฅผ ์ฝ์ด์ต๋๋ค
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 # ์์์ ์
# ๋ชจ๋ธ์ ๋ฏธ๋ฆฌ ํ์ต๋ ๊ฐ์ค์น๋ก ์ด๊ธฐํํฉ๋๋ค
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
# ๋ชจ๋ธ์ ์ถ๋ก ๋ชจ๋๋ก ์ ํํฉ๋๋ค
torch_model.eval()
3. ๋ชจ๋ธ ๋ณํ
- Tracing์ด๋ Scripting์ ํตํด์ PyTorch ๋ชจ๋ธ์ ๋ณํํ ์ ์๋๋ฐ ์ด ์์ ์์๋ tracing์ ํตํด ๋ณํ๋ ๋ชจ๋ธ์ ์ฌ์ฉ
- ๋ชจ๋ธ์ ๋ณํํ๊ธฐ ์ํด์๋ torch.onnx.export() ํจ์๋ฅผ ํธ์ถ
- ์ด ํจ์๋ ๋ชจ๋ธ์ ์คํํ์ฌ ๊ทธ ์คํ์ ์ถ์ (trace)ํ ๋ค์ ์ถ์ ๋ ๋ชจ๋ธ์ ์ง์ ๋ ํ์ผ๋ก ๋ด๋ณด๋
- export ํจ์๊ฐ ๋ชจ๋ธ์ ์คํํ๊ธฐ ๋๋ฌธ์, ์ฐ๋ฆฌ๊ฐ ์ง์ ํ ์๋ฅผ ์ ๋ ฅ๊ฐ์ผ๋ก ๋๊ฒจ์ฃผ์ด์ผ ํ๊ณ ์ด ํ ์์ ๊ฐ์ ์๋ง์ ์๋ฃํ๊ณผ shape์ด๋ผ๋ฉด ๋๋คํ ๊ฐ์ด์ด๋ ๋ฌด๊ด
# ๋ชจ๋ธ์ ๋ํ ์
๋ ฅ๊ฐ
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)
# ๋ชจ๋ธ ๋ณํ
torch.onnx.export(torch_model, # ์คํ๋ ๋ชจ๋ธ
x, # ๋ชจ๋ธ ์
๋ ฅ๊ฐ (ํํ ๋๋ ์ฌ๋ฌ ์
๋ ฅ๊ฐ๋ค๋ ๊ฐ๋ฅ)
"super_resolution.onnx", # ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก (ํ์ผ ๋๋ ํ์ผ๊ณผ ์ ์ฌํ ๊ฐ์ฒด ๋ชจ๋ ๊ฐ๋ฅ)
export_params=True, # ๋ชจ๋ธ ํ์ผ ์์ ํ์ต๋ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ ์ ์ฅํ ์ง์ ์ฌ๋ถ
opset_version=10, # ๋ชจ๋ธ์ ๋ณํํ ๋ ์ฌ์ฉํ ONNX ๋ฒ์
do_constant_folding=True, # ์ต์ ํ์ ์์ํด๋ฉ์ ์ฌ์ฉํ ์ง์ ์ฌ๋ถ
input_names = ['input'], # ๋ชจ๋ธ์ ์
๋ ฅ๊ฐ์ ๊ฐ๋ฆฌํค๋ ์ด๋ฆ
output_names = ['output'], # ๋ชจ๋ธ์ ์ถ๋ ฅ๊ฐ์ ๊ฐ๋ฆฌํค๋ ์ด๋ฆ
dynamic_axes={'input' : {0 : 'batch_size'}, # ๊ฐ๋ณ์ ์ธ ๊ธธ์ด๋ฅผ ๊ฐ์ง ์ฐจ์
'output' : {0 : 'batch_size'}})
์ ์ฅ๋ onnx ํ์ผ์๋ ๋ชจ๋ธ์ ๋งค๊ฐ๋ณ์์ ๋คํธ์ํฌ ๊ตฌ์กฐ๋ฅผํฌํจํ๋ ๋ฐ์ด๋๋ฆฌ ํ๋กํ ์ฝ ๋ฒํผ๊ฐ ํฌํจ๋์ด ์์
๋ํ layer ๊ฐ์ ์ ์ถ๋ ฅ ํฌ๊ธฐ๋ฅผ ํ์ธํ๊ธฐ ์ํด์ ์ ์ฅ๋ ONNX๋ฅผ ๋ค์ ๋ถ๋ฌ์์ ์๋์ ๊ฐ์ ๋ฐฉ์์ผ๋ก shape ์ ๋ณด๋ฅผ ์ ์ฅํ๋ ๊ณผ์ ์ด ํ์
from onnx import shape_inference
path = "./super_resolution.onnx"
onnx.save(onnx.shape_inference.infer_shapes(onnx.load(path)), path)
4. ONNX ๋ชจ๋ธ ํ์ธ
- ONNX ๋ฐํ์์์ ๋ณํ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ ๋ ๊ฐ์ ๊ฒฐ๊ณผ๋ฅผ ์ป๋์ง ํ์ธํ๊ธฐ ์ํด์ torch_out ๋ฅผ ๊ณ์ฐ
- ONNX ๋ฐํ์์์์ ๋ชจ๋ธ ๊ฒฐ๊ณผ๊ฐ์ ํ์ธํ๊ธฐ ์ ์ ๋จผ์ ONNX API๋ฅผ ์ฌ์ฉํด ONNX ๋ชจ๋ธ์ ํ์ธ
- onnx.load("super_resolution.onnx") ๋ ์ ์ฅ๋ ๋ชจ๋ธ์ ์ฝ์ด์จ ํ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ์ทจํฉํ์ฌ ์ ์ฅํ๊ณ ์๋ ์์ ํ์ผ ์ปจํ ์ด๋์ธ onnx.ModelProto๋ฅผ ๋ฐํ
- onnx.checker.check_model(onnx_model) ๋ ๋ชจ๋ธ์ ๊ตฌ์กฐ๋ฅผ ํ์ธํ๊ณ ๋ชจ๋ธ์ด ์ ํจํ ์คํค๋ง(valid schema)๋ฅผ ๊ฐ์ง๊ณ ์๋์ง๋ฅผ ์ฒดํฌ
import onnx
onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)
5. ONNX ๋ฐํ์๊ณผ pytorch ์ถ๋ ฅ ๋น๊ต
- ONNX ๋ฐํ์์ Python API๋ฅผ ํตํด ๊ฒฐ๊ณผ๊ฐ์ ๊ณ์ฐ
- ์ด ๋ถ๋ถ์ ๋ณดํต ๋ณ๋์ ํ๋ก์ธ์ค ๋๋ ๋ณ๋์ ๋จธ์ ์์ ์คํ๋์ง๋ง, ์ด ํํ ๋ฆฌ์ผ์์๋ ๋ชจ๋ธ์ด ONNX ๋ฐํ์๊ณผ PyTorch์์ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๋์ง๋ฅผ ํ์ธํ๊ธฐ ์ํด ๋์ผํ ํ๋ก์ธ์ค์์ ๊ณ์ ์คํ
- ๋ชจ๋ธ์ ONNX ๋ฐํ์์์ ์คํํ๊ธฐ ์ํด์๋ ๋ฏธ๋ฆฌ ์ค์ ๋ ์ธ์๋ค๋ก ๋ชจ๋ธ์ ์ํ ์ถ๋ก ์ธ์ ์ ์์ฑ
- ์ธ์ ์ด ์์ฑ๋๋ฉด, ๋ชจ๋ธ์ run() API๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์คํ
- ์ด API์ ๋ฆฌํด๊ฐ์ ONNX ๋ฐํ์์์ ์ฐ์ฐ๋ ๋ชจ๋ธ์ ๊ฒฐ๊ณผ๊ฐ๋ค์ ํฌํจํ๊ณ ์๋ ๋ฆฌ์คํธ
- PyTorch์ ONNX ๋ฐํ์์์ ์ฐ์ฐ๋ ๊ฒฐ๊ณผ๊ฐ์ด ์๋ก ์ผ์นํ๋์ง ์ค์ฐจ๋ฒ์ (rtol=1e-03, atol=1e-05) ์ด๋ด์์ ํ์ธ
import onnxruntime
ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# ONNX ๋ฐํ์์์ ๊ณ์ฐ๋ ๊ฒฐ๊ณผ๊ฐ
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
# ONNX ๋ฐํ์๊ณผ PyTorch์์ ์ฐ์ฐ๋ ๊ฒฐ๊ณผ๊ฐ ๋น๊ต
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
*์ฃผ์ ์ฌํญ
PyTorch ๋ชจ๋ธ์ NumPy ๋๋ Python ์ ํ ๋ฐ ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์์ฑํ ์ ์์ง๋ง tracing ์ค์ NumPy ๋๋ Python ์ ํ์ ๋ชจ๋ ๋ณ์(torch.Tensor๊ฐ ์๋)๋ ์์๋ก ๋ณํ๋๋ฏ๋ก ํด๋น ๊ฐ์ด ๋ค์์ ๋ฐ๋ผ ๋ณ๊ฒฝ๋์ด์ผ ํ๋ ๊ฒฝ์ฐ ์๋ชป๋ ๊ฒฐ๊ณผ๊ฐ ์์ฑ๋๋ค.