Torch-TensorRT๋ PyTorch์ NVIDIA์ TensorRT๋ฅผ ํตํฉํ์ฌ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ต์ ํํ๊ณ ๊ฐ์ํํ๋ ๋ฐ ์ฌ์ฉ๋๋ PyTorch/TorchScript/FX์ฉ ์ปดํ์ผ๋ฌ์ด๋ค. Torch-TensorRT๋ PyTorch ์ต์คํ ์ ์ผ๋ก ๋์ํ๋ฉฐ JIT(Just In Time) ๋ฐํ์์ ์ํ ํ๊ฒ ํตํฉ๋๋ ๋ชจ๋์ ์ปดํ์ผํ๋ค.
NVIDIA TensorRT๋ NVIDIA GPU์์ ๋ชจ๋ธ์ ๋ ๋น ๋ฅด๊ฒ ์คํํ๊ธฐ ์ํ ์ต์ ํ๋ ๋ฐํ์ ์์ง์ผ๋ก, ํนํ ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ ๋ฐฐํฌ ํ๊ฒฝ์์ ๋ ํจ์จ์ ์ผ๋ก ์คํํ๊ณ ์ถ๋ก (inference) ์ฑ๋ฅ์ ํฅ์์ํค๋ ๋ฐ ์ฌ์ฉ๋๋ค.
๊ธฐ์กด์ ํ์ด์ฌ์ผ๋ก TensorRT๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ์คํ์์ค ์ปค๋ฎค๋ํฐ์์ ๊ฐ๋ฐํ torch2trt ํจํค์ง๋ฅผ ์ฌ์ฉํด์ pytorch ๋ชจ๋ธ์ tensorRT ํธํ ํ์์ผ๋ก ๋ณํํด์ ๋ชจ๋ธ ์ธํผ๋ฐ์ค๋ฅผ ๊ฐ์ํ์์ผฐ๋ค. ํ์ง๋ง NIVIDA์ PyTorch๊ฐ ๊ณต์์ ์ผ๋ก ์ ๊ณตํ๋ Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด PyTorch ๋ชจ๋ธ์ ๋ณํํ ๋ ์ต์ ํ ์์ค์ ๋ ์ธ๋ฐํ๊ฒ ์ ์ดํ ์ ์๋ค.
Torch-TensorRT๋ ๋ค์๊ณผ ๊ฐ์ ์ฃผ์ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ค
- ๋ชจ๋ธ ์ต์ ํ: Torch-TensorRT๋ PyTorch ๋ชจ๋ธ์ TensorRT ํธํ ํ์์ผ๋ก ๋ณํํ์ฌ TensorRT์ ์ต์ ํ ๊ธฐ๋ฅ์ ํ์ฉํ ์ ์์. ์ด๋ก์จ ๋ชจ๋ธ์ ์ฐ์ฐ ๊ทธ๋ํ๊ฐ ์ต์ ํ๋๊ณ , GPU ๊ฐ์์ ์ํ ํน์ ์ฐ์ฐ ์ปค๋๋ก ๋์ฒด๋จ.
- TensorRT ๋ชจ๋ธ ์์ฑ: Torch-TensorRT๋ฅผ ์ฌ์ฉํ์ฌ ์ต์ ํ๋ ๋ชจ๋ธ์ ์์ฑํ ์ ์์. ์ด๋ ๊ฒ ์์ฑ๋ ๋ชจ๋ธ์ TensorRT ์์ง์ ์ฌ์ฉํ์ฌ ํจ์จ์ ์ผ๋ก ์ถ๋ก ์ ์ํํ ์ ์์.
- End-to-end ๊ฐ์ํ: Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด PyTorch ๋ชจ๋ธ์ ๊ฐ์ ธ์ TensorRT ๋ชจ๋ธ๋ก ๋ณํํ๊ณ ์ถ๋ก ์ ์ํํ๋ ์ ์ฒด ์๋ ํฌ ์๋ ํ์ดํ๋ผ์ธ์ ๊ตฌ์ถํ ์ ์์.
- ๋์ ๊ทธ๋ํ ์ง์: PyTorch์ ๋์ ๊ทธ๋ํ ํน์ฑ์ ์ ์งํ๋ฉด์ TensorRT๋ก์ ์ต์ ํ๊ฐ ๊ฐ๋ฅํ๋ฏ๋ก, ๋ค์ํ ๋ชจ๋ธ ์ํคํ ์ฒ์ ๋ํ ์ง์์ ์ ๊ณตํฉ๋๋ค.
Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ NVIDIA GPU์์ ๋ ํจ์จ์ ์ผ๋ก ์คํํ๊ณ , ์ถ๋ก ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ผ๋ฉฐ, ์ด๋ ์ค์๊ฐ ์๊ตฌ ์ฌํญ์ด ์๋ ์์ฉ ํ๋ก๊ทธ๋จ ๋ฐ ์๋น์ค์ ํนํ ์ ์ฉํ๋ค.
Torch-TensorRT — Torch-TensorRT v2.2.0.dev0+8ebf24d documentation
Shortcuts
pytorch.org
๋ฒ์ ์ด ๋ง๋ pytorch๊ฐ ์ค์น๋์ด ์๋ค๋ฉด tensorrt, torch_tensorrt ๋ง ์ค์นํ๋ฉด ์์ ์ฝ๋๋ฅผ ์คํํด ๋ณผ ์ ์๋ค.
pip install tensorrt torch_tensorrt
ํ์ง๋ง ํ๋ฒ์ ์ ๋๋ก ๋์ํ์ง ์๋ ๊ฒฝ์ฐ๋ ๋ง๋ค... ใ TensorRT๋ ํ๊ฒฝ ๊ตฌ์ฑ์ด ์๊ฐ๋ณด๋ค ํ๋ค๊ธฐ ๋๋ฌธ์ ์ ๋์ํ๋ ํจํค์ง ๋ฒ์ ์กฐํฉ์ ์ ํ์ธํด์ ์ค์นํด์ผ ํ๋ค. Torch-TensorRT ๊นํ ๋ฆด๋ฆฌ์ค ํ์ด์ง์์ Torch-TensorRT ๋ฒ์ ๋ณ pytorch, cuda, tensorrt์์ ๋ฒ์ ํธํ์ฑ์ ํ์ธํ ์ ์๋ค.
→ https://github.com/pytorch/TensorRT/releases
๊ฐ์ธ์ ์ผ๋ก๋ TensorRT 8.6.1์ด ์ค์น๋ ์ธ์คํด์ค๋ฅผ ๋์์ pytorch 2.0 cuda 11.8 ์ ์ค์นํ ๋ค์ torch-tensorrt v1.4๋ฅผ ์ค์นํ๋ค. ๊ทธ๋ฌ๋๋... torchvision์ ๋ฒ์ ์ด ํธํ์ด ์๋๋ค๊ณ ํด์ torchvision๋ง ์ง์ ๋ค๊ฐ ๋ค์ ์ค์นํ๋๊น ์ ๋์ํ๋ค.
Python Torch-TensorRT ์ฌ์ฉ ์์
Torch-TensorRT๋ก ์ ๋ ฅ์ ์ปดํ์ผํ๋ ค๋ฉด torch.nn.ModuleTorch-TensorRT์ ๋ชจ๋๊ณผ ์ ๋ ฅ์ ์ ๊ณตํ๋ฉด ๋๋ค. ๊ทธ๋ฌ๋ฉด ๋ค๋ฅธ PyTorch ๋ชจ๋์ ์คํํ๊ฑฐ๋ ์ถ๊ฐํ ์ ์๋๋ก ์ต์ ํ๋ TorchScript ๋ชจ๋์ด ๋ฐํ๋๋ค.
์ ๋ ฅ์ ์ ๋ ฅ ํ ์์ ๋ชจ์, ๋ฐ์ดํฐ ์ ํ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์ํ๋ torch_tensorrt.Input ํด๋์ค list์ด๋ค. ์ ๋ ฅ์ด ํํ์ด๋ Tensor ๋ชฉ๋ก๊ณผ ๊ฐ์ ๋ ๋ณต์กํ ๋ฐ์ดํฐ ์ ํ์ธ ๊ฒฝ์ฐ ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ input_signature์ ๊ฐ์ ์ปฌ๋ ์ ๊ธฐ๋ฐ ์ ๋ ฅ(e.g. List[Tensor], Tuple[Tensor, Tensor])์ ์ง์ ํ ์ ์๋ค(๋ ๋ฒ์งธ ์์ ์ฝ๋).
์์ง์ด๋ ๋์ ์ฅ์น์ ์๋ ์ ๋ฐ๋์ ๊ฐ์ ์ค์ ์ ์ง์ ํ ์๋ ์๊ณ , ์ปดํ์ผ ํ์๋ ๋ค๋ฅธ ๋ชจ๋๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ๋ชจ๋์ ์ ์ฅํ์ฌ ๋ฐฐํฌ ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ก๋ํ ์ ์๋ค. TensorRT/TorchScript ๋ชจ๋์ ๋ก๋ํ๋ ค๋ฉด ๋จผ์ torch_tensorrt๋ฅผ ์ํฌํธ ํด์ผ ํ๋ค.
import torch_tensorrt
...
model = MyModel().eval() # torch module needs to be in eval (not training) mode
inputs = [
torch_tensorrt.Input(
min_shape=[1, 1, 16, 16],
opt_shape=[1, 1, 32, 32],
max_shape=[1, 1, 64, 64],
dtype=torch.half,
)
]
enabled_precisions = {torch.float, torch.half} # Run with fp16
trt_ts_module = torch_tensorrt.compile(
model, inputs=inputs, enabled_precisions=enabled_precisions
)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
- ์ค์นํ torch_tensorrt๋ฅผ ์ํฌํธ
- pytorch ๋ชจ๋ธ์ ๋ก๋ํ๊ณ torch_tensor.complie ์ํ
- ์ปดํ์ผ ์ input์ ์ ๋ ฅํ๋๋ฐ, ์ด ๋ ๋ชจ๋ธ ์ ๋ ฅ์ ์ต์ ํฌ๊ธฐ, ์ต์ ํ ํฌ๊ธฐ, ์ต๋ ํฌ๊ธฐ๋ฅผ ์ง์
- ์ดํ ๋ณํํ ๋ชจ๋ธ์ input์ ์ ๋ ฅํด ์ฃผ๊ธฐ๋ง ํ๋ฉด ๋จ
- torch.jit.save๋ก ๋ณํํ ๋ชจ๋ธ ์ ์ฅ ๊ฐ๋ฅ
*์ ๋ ฅ ํฌ๊ธฐ๋ฅผ ๋์ ์ผ๋ก ์กฐ์ ํ ์ ์๊ธด ํ์ง๋ง, ๋ชจ๋ธ ์ ํ๋์ ์๋๋ opt_shape (์ต์ ํ๋ ํฌ๊ธฐ)์์๋ง ๋ณด์ฅ ๋ฐ์ ์ ์์. ์ ๋ ฅ ํฌ๊ธฐ๊ฐ opt_shape๊ณผ ์ฐจ์ด๊ฐ ๋ง์ด ๋ ์๋ก pytorch ๋ชจ๋ธ ์ถ๋ ฅ๊ณผ tensorrt๋ก ๋ณํํ ๋ชจ๋ธ์ ์ถ๋ ฅ๊ฐ์ ์ฐจ์ด๊ฐ ๋ง์ด ๋จ.
# Sample using collection-based inputs via the input_signature argument
import torch_tensorrt
...
model = MyModel().eval()
# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a docstring of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
[torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
(torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}
trt_ts_module = torch_tensorrt.compile(
model, input_signature=input_signature, enabled_precisions=enabled_precisions
)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
- input_signature๋ฅผ ์ฌ์ฉํ๋ ์์
- ๋ค์ํ ํํ์ ์ ๋ ฅ์ ๋ฐ์ ์ ์์
# Deployment application
import torch
import torch_tensorrt
trt_ts_module = torch.jit.load("trt_ts_module.ts")
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
- tensorRT๋ก ๋ณํ & ์ ์ฅํ ๋ชจ๋ธ์ ๋ก๋ํด์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ
# PyTorch ๋ชจ๋ธ ์ถ๋ก
torch_output = model(input_data)
# TensorRT ๋ชจ๋ธ ์ถ๋ก
trt_output = trt_ts_module(input_data)
# ๋ ์ถ๋ ฅ ๊ฐ์ ์ ์ฌ์ฑ ํ์ธ
if torch.allclose(torch_output, trt_output, atol=1e-3):
print("PyTorch์ TensorRT ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ์ ์ฌํฉ๋๋ค.")
else:
print("PyTorch์ TensorRT ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ๋ค๋ฆ
๋๋ค.")
- pytorch ๋ชจ๋ธ๊ณผ tensorrt ๋ชจ๋ธ์ ์ถ๋ ฅ ์ ์ฌ์ฑ ํ ์คํธ