
[ONNX] Converting a PyTorch Model to ONNX and Running It

by 뭅즤 2022. 12. 21.
๋ฐ˜์‘ํ˜•
ONNX (Open Neural Network eXchange)

 

ONNX is an open format for representing machine learning models, and ONNX Runtime enables efficient inference on a wide range of platforms and hardware (operating systems such as Windows, Linux, and macOS, as well as hardware such as CPUs and GPUs). This is why ONNX is used when a model has to move between different frameworks (PyTorch ↔ TensorFlow ↔ Caffe2 ↔ MXNet ↔ ...).

 

 

*Reference

 

 

ONNX Example

 

1. Implementing the PyTorch model

 

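# Required imports
import numpy as np
import torch
import torch.nn as nn
import torch.onnx
import torch.utils.model_zoo as model_zoo

# SuperResolutionNet (not shown in the post; layout assumed from the official PyTorch ONNX tutorial)
class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        return self.pixel_shuffle(self.conv4(x))

# Create the super-resolution model (upscale factor 3)
torch_model = SuperResolutionNet(upscale_factor=3)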

 

 

2. Loading the pretrained weights

 

# Load the pretrained weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1    # arbitrary batch size

# Initialize the model with the pretrained weights
# (without a GPU, every stored tensor is kept on the CPU)
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

# Switch the model to inference mode
torch_model.eval()

 

 

3. Converting the model

 

  • A PyTorch model can be converted by tracing or by scripting; this example converts it by tracing.
  • To convert the model, call torch.onnx.export().
  • This function runs the model, traces (records) its execution, and then exports the traced model to the specified file.
  • Because export runs the model, we have to pass it an input tensor ourselves. The tensor's values can be random as long as its dtype and shape are correct.
# Input to the model
x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
torch_out = torch_model(x)

# Export the model
torch.onnx.export(torch_model,               # model to run
                  x,                         # model input (use a tuple for multiple inputs)
                  "super_resolution.onnx",   # where to save the model (file path or file-like object)
                  export_params=True,        # store the trained weights inside the model file
                  opset_version=10,          # ONNX opset version to target
                  do_constant_folding=True,  # apply constant folding as an optimization
                  input_names = ['input'],   # name for the model's input
                  output_names = ['output'], # name for the model's output
                  dynamic_axes={'input' : {0 : 'batch_size'},    # axes with variable length
                                'output' : {0 : 'batch_size'}})

์ €์žฅ๋œ onnx ํŒŒ์ผ์—๋Š” ๋ชจ๋ธ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜์™€ ๋„คํŠธ์›Œํฌ ๊ตฌ์กฐ๋ฅผํฌํ•จํ•˜๋Š” ๋ฐ”์ด๋„ˆ๋ฆฌ ํ”„๋กœํ† ์ฝœ ๋ฒ„ํผ๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ์Œ

 

To check the input and output sizes between layers, you also need to reload the saved ONNX model and store its shape information, as shown below.

import onnx
from onnx import shape_inference

path = "./super_resolution.onnx"
# Infer the shape of every intermediate tensor and save it back into the same file
onnx.save(shape_inference.infer_shapes(onnx.load(path)), path)

 

 

4. Checking the ONNX model

 

  • torch_out, computed in step 3, serves as the reference for checking that the converted model gives the same result in ONNX Runtime.
  • Before checking the model's output in ONNX Runtime, first validate the ONNX model with the ONNX API.
    • onnx.load("super_resolution.onnx") reads the saved model and returns an onnx.ModelProto, the top-level file container that bundles the machine learning model.
    • onnx.checker.check_model(onnx_model) verifies the model's structure and checks that the model has a valid schema.

 

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)

 

 

5. Comparing the ONNX Runtime and PyTorch outputs

 

  • Compute the outputs through ONNX Runtime's Python API.
  • In practice this part usually runs in a separate process or on a separate machine, but here it continues in the same process so we can check that ONNX Runtime and PyTorch produce the same result.
  • To run the model in ONNX Runtime, create an inference session for it with the desired configuration arguments.
  • Once the session is created, run the model with the session's run() API.
  • run() returns a list containing the model outputs computed by ONNX Runtime.
  • Finally, check that the outputs computed by PyTorch and ONNX Runtime agree within the tolerance (rtol=1e-03, atol=1e-05).
import onnxruntime

# Create an inference session on the CPU execution provider
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    # Tensors that track gradients must be detached before converting to NumPy
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# Compute the output with ONNX Runtime (None = return all outputs)
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# Compare the ONNX Runtime and PyTorch outputs
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")
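The rtol/atol pair combines a relative and an absolute term: np.testing.assert_allclose(actual, desired, ...) passes when |actual - desired| <= atol + rtol * |desired| holds elementwise, so in step 5 the relative term scales with ort_outs[0], the second argument. A NumPy-only sketch with made-up numbers (not real model outputs):

```python
import numpy as np

rtol, atol = 1e-03, 1e-05
desired = np.array([1.0, 100.0])          # stand-in for ort_outs[0]
close = desired + np.array([5e-4, 5e-2])  # stand-in for a matching PyTorch output
far = desired + np.array([2e-3, 5e-2])    # first element drifts too far

def within_tolerance(actual, desired):
    # Same elementwise criterion that assert_allclose applies
    return np.abs(actual - desired) <= atol + rtol * np.abs(desired)

print(within_tolerance(close, desired))  # [ True  True]
print(within_tolerance(far, desired))    # [False  True]

np.testing.assert_allclose(close, desired, rtol=rtol, atol=atol)  # passes silently
try:
    np.testing.assert_allclose(far, desired, rtol=rtol, atol=atol)
    far_rejected = False
except AssertionError:
    far_rejected = True
```

Note how the same absolute error of 0.05 is accepted at 100.0 but an error of 0.002 is rejected at 1.0: the relative term dominates for large values, the absolute term only matters near zero.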

 

*Caution

PyTorch models can be written with NumPy or plain Python types and functions, but during tracing every variable of a NumPy or Python type (anything that is not a torch.Tensor) is converted into a constant. If such a value is supposed to change depending on the input, the exported model will produce wrong results.
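The effect can be reproduced with torch.jit.trace, which, to our understanding, records operations the same way the TorchScript-based torch.onnx.export path does. In this hypothetical normalize function, float(x.sum()) turns a tensor into a Python number, so the traced version keeps dividing by the sum seen at trace time (2.0) no matter what it receives later.

```python
import warnings

import torch

def normalize(x):
    # float(...) converts a tensor into a plain Python number; under tracing,
    # that number is recorded as a constant instead of as an operation
    total = float(x.sum())
    return x / total

example = torch.tensor([1.0, 1.0])  # sum is 2.0 at trace time
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(normalize, example)

new_input = torch.tensor([1.0, 3.0])
eager_out = normalize(new_input)  # divides by the real sum, 4.0
traced_out = traced(new_input)    # still divides by the frozen 2.0
print(eager_out, traced_out)
```

Keeping such computations as tensor operations (for example x / x.sum()) lets the tracer record them instead of freezing their result.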

 

 

 

๋ฐ˜์‘ํ˜•