
[Model Inference] How to Use Torch-TensorRT | Optimizing Deep Learning Models and Accelerating Inference

by 뭅즤 2023. 10. 2.
๋ฐ˜์‘ํ˜•

 

 

Torch-TensorRT is a compiler for PyTorch/TorchScript/FX that integrates PyTorch with NVIDIA TensorRT to optimize and accelerate deep learning models. It runs as a PyTorch extension and compiles modules that integrate seamlessly into the JIT (Just In Time) runtime.

 

NVIDIA TensorRT is an optimized runtime engine for running models faster on NVIDIA GPUs. It is used in particular to run deep learning models more efficiently in deployment environments and to improve inference performance.

 

Previously, using TensorRT from Python meant converting a PyTorch model into a TensorRT-compatible format with the open-source torch2trt package to accelerate inference. With Torch-TensorRT, which NVIDIA and PyTorch provide officially, you can control the level of optimization more finely when converting a PyTorch model.

Torch-TensorRT provides the following key features:

  • Model optimization: Torch-TensorRT converts a PyTorch model into a TensorRT-compatible form so it can use TensorRT's optimizations. The model's computation graph is optimized, and operations are replaced with kernels specialized for GPU acceleration.
  • TensorRT model generation: Torch-TensorRT produces optimized models that run inference efficiently on the TensorRT engine.
  • End-to-end acceleration: Torch-TensorRT lets you build a complete end-to-end pipeline that takes a PyTorch model, converts it to a TensorRT model, and runs inference.
  • Dynamic graph support: TensorRT optimization is possible while keeping PyTorch's dynamic-graph nature (operations TensorRT does not support can fall back to running in PyTorch), so a wide range of model architectures is supported.

 

With Torch-TensorRT, deep learning models run more efficiently on NVIDIA GPUs and inference can become significantly faster, which is especially useful for applications and services with real-time requirements.


Torch-TensorRT Install

 

Reference: Torch-TensorRT v2.2.0.dev0+8ebf24d documentation (pytorch.org)

If a compatible version of PyTorch is already installed, installing just tensorrt and torch_tensorrt is enough to run the example code.

pip install tensorrt torch_tensorrt

 

Torch-TensorRT GitHub Releases page

That said, it often doesn't work on the first try... TensorRT environments are harder to set up than you might expect, so make sure you install a combination of package versions that is known to work together. The Torch-TensorRT GitHub releases page lists the compatible PyTorch, CUDA, and TensorRT versions for each Torch-TensorRT release.

https://github.com/pytorch/TensorRT/releases

 

๊ฐœ์ธ์ ์œผ๋กœ๋Š” TensorRT 8.6.1์ด ์„ค์น˜๋œ ์ธ์Šคํ„ด์Šค๋ฅผ ๋„์›Œ์„œ pytorch 2.0 cuda 11.8 ์„ ์„ค์น˜ํ•œ ๋’ค์— torch-tensorrt v1.4๋ฅผ ์„ค์น˜ํ–ˆ๋‹ค. ๊ทธ๋žฌ๋”๋‹ˆ... torchvision์˜ ๋ฒ„์ „์ด ํ˜ธํ™˜์ด ์•ˆ๋œ๋‹ค๊ณ  ํ•ด์„œ torchvision๋งŒ ์ง€์› ๋‹ค๊ฐ€ ๋‹ค์‹œ ์„ค์น˜ํ•˜๋‹ˆ๊นŒ ์ž˜ ๋™์ž‘ํ•œ๋‹ค.

 

Python Torch-TensorRT Usage Examples

To compile a torch.nn.Module with Torch-TensorRT, you just pass the module and its inputs to Torch-TensorRT. It returns an optimized TorchScript module that you can run directly or embed in other PyTorch modules.

 

The inputs are a list of torch_tensorrt.Input objects, each defining the shape, data type, and memory format of an input tensor. If the module takes more complex inputs, such as tuples or lists of tensors, you can specify collection-based inputs (e.g. List[Tensor], Tuple[Tensor, Tensor]) with the input_signature argument instead (second example code).

 

You can also configure settings such as the engine's operating precision and the target device. After compilation, you can save the module like any other module and load it in a deployment application. To load a TensorRT/TorchScript module, you must import torch_tensorrt first.

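import torch
import torch_tensorrt

...

# MyModel is the user's own nn.Module (definition elided above)
model = MyModel().eval().cuda()  # torch module needs to be in eval (not training) mode, on the GPU

inputs = [
    torch_tensorrt.Input(
        min_shape=[1, 1, 16, 16],
        opt_shape=[1, 1, 32, 32],
        max_shape=[1, 1, 64, 64],
        dtype=torch.half,
    )
]
enabled_precisions = {torch.float, torch.half}  # Run with fp16

trt_ts_module = torch_tensorrt.compile(
    model,
    ir="ts",  # TorchScript frontend, so the result can be saved with torch.jit.save
    inputs=inputs,
    enabled_precisions=enabled_precisions,
)

# Example input (hypothetical): any shape between min_shape and max_shape is accepted
input_data = torch.randn(1, 1, 32, 32)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")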
  • Import the installed torch_tensorrt
  • Load the PyTorch model and call torch_tensorrt.compile
  • Pass the inputs at compile time, specifying the minimum, optimal, and maximum shapes of the model input
  • After that, just feed inputs to the converted model
  • Save the converted model with torch.jit.save

*The input size can vary dynamically, but only within [min_shape, max_shape]; TensorRT tunes its kernels for opt_shape, so speed is best at that size. In my experience, the further the input size was from opt_shape, the larger the gap between the PyTorch model's output and the TensorRT-converted model's output.
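The min/opt/max contract can be illustrated without a GPU. The sketch below uses two hypothetical helpers (in_profile and distance_from_opt, not part of the Torch-TensorRT API) with the shapes from the first example: in_profile checks dimension by dimension that a runtime shape lies between min_shape and max_shape, and distance_from_opt is a rough, made-up measure of how far a shape is from opt_shape.

```python
# Shapes copied from the torch_tensorrt.Input example above
MIN_SHAPE = (1, 1, 16, 16)
OPT_SHAPE = (1, 1, 32, 32)
MAX_SHAPE = (1, 1, 64, 64)


def in_profile(shape, min_shape=MIN_SHAPE, max_shape=MAX_SHAPE):
    """True if the rank matches and every dim satisfies min <= dim <= max."""
    if len(shape) != len(min_shape) or len(shape) != len(max_shape):
        return False
    return all(lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape))


def distance_from_opt(shape, opt_shape=OPT_SHAPE):
    """Sum of absolute per-dimension differences from opt_shape (illustrative heuristic only)."""
    return sum(abs(d - o) for d, o in zip(shape, opt_shape))


for s in [(1, 1, 32, 32), (1, 1, 16, 64), (1, 1, 8, 8), (2, 1, 32, 32)]:
    status = "within profile" if in_profile(s) else "outside profile"
    print(s, status, "distance from opt:", distance_from_opt(s))
```

A shape such as (2, 1, 32, 32) falls outside the profile because the batch dimension's max is 1, so it would need a recompile with a wider range rather than just a different opt_shape.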

 

# Sample using collection-based inputs via the input_signature argument
import torch
import torch_tensorrt

...

model = MyModel().eval().cuda()

# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a forward signature of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
    [torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
    (torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}

trt_ts_module = torch_tensorrt.compile(
    model,
    ir="ts",
    input_signature=input_signature,
    enabled_precisions=enabled_precisions,
)

# Inputs must follow the same nesting as input_signature: a list of two tensors, then a tuple of two tensors
input0 = [torch.randn(64, 64, device="cuda", dtype=torch.half) for _ in range(2)]
input1 = tuple(torch.randn(64, 64, device="cuda", dtype=torch.half) for _ in range(2))
result = trt_ts_module(input0, input1)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
  • Example using input_signature
  • Handles inputs with more complex structures, such as lists and tuples of tensors

 

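# Deployment application
import torch
import torch_tensorrt  # must be imported so torch.jit.load can find the TensorRT runtime ops

trt_ts_module = torch.jit.load("trt_ts_module.ts")
# Example input (hypothetical), matching the shape range and dtype used at compile time
input_data = torch.randn(1, 1, 32, 32)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)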
  • How to load and use a model that was converted to TensorRT and saved

 

with torch.no_grad():
    # PyTorch model inference (the original fp32 model expects fp32 input)
    torch_output = model(input_data.float())

    # TensorRT model inference (compiled for fp16 input)
    trt_output = trt_ts_module(input_data)

# Check how similar the two outputs are; cast both to fp32 so the dtypes match
if torch.allclose(torch_output.float(), trt_output.float(), atol=1e-3):
    print("PyTorch and TensorRT model outputs are similar.")
else:
    print("PyTorch and TensorRT model outputs differ.")
  • Test how closely the PyTorch model's output matches the TensorRT model's output
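torch.allclose(a, b, rtol, atol) treats values as close when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08, so the atol=1e-3 in the comparison above is combined with the default rtol. The pure-Python sketch below mirrors that elementwise criterion on flat lists (ignoring NaN handling) and also reports the maximum absolute difference, which is often more informative than a single True/False when comparing fp16 TensorRT output with an fp32 reference. The function names and sample values are hypothetical.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Elementwise |a - b| <= atol + rtol * |b|, the criterion torch.allclose documents."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


def max_abs_diff(a, b):
    """Largest elementwise absolute difference (0.0 for empty inputs)."""
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


torch_out = [0.1234, -1.5000, 2.0000]  # hypothetical fp32 reference values
trt_out = [0.1240, -1.4995, 2.0010]    # hypothetical fp16 TensorRT values

print("allclose:", allclose(torch_out, trt_out, atol=1e-3))
print("max abs diff:", max_abs_diff(torch_out, trt_out))
```

Logging the maximum difference for inputs near and far from opt_shape is one way to quantify the accuracy gap described earlier, instead of relying only on a pass/fail threshold.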
๋ฐ˜์‘ํ˜•