Torch-TensorRT๋ PyTorch์ NVIDIA์ TensorRT๋ฅผ ํตํฉํ์ฌ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ต์ ํํ๊ณ ๊ฐ์ํํ๋ ๋ฐ ์ฌ์ฉ๋๋ PyTorch/TorchScript/FX์ฉ ์ปดํ์ผ๋ฌ์ด๋ค. Torch-TensorRT๋ PyTorch ์ต์คํ ์ ์ผ๋ก ๋์ํ๋ฉฐ JIT(Just In Time) ๋ฐํ์์ ์ํ ํ๊ฒ ํตํฉ๋๋ ๋ชจ๋์ ์ปดํ์ผํ๋ค.
NVIDIA TensorRT๋ NVIDIA GPU์์ ๋ชจ๋ธ์ ๋ ๋น ๋ฅด๊ฒ ์คํํ๊ธฐ ์ํ ์ต์ ํ๋ ๋ฐํ์ ์์ง์ผ๋ก, ํนํ ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ ๋ฐฐํฌ ํ๊ฒฝ์์ ๋ ํจ์จ์ ์ผ๋ก ์คํํ๊ณ ์ถ๋ก (inference) ์ฑ๋ฅ์ ํฅ์์ํค๋ ๋ฐ ์ฌ์ฉ๋๋ค.
๊ธฐ์กด์ ํ์ด์ฌ์ผ๋ก TensorRT๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์๋ ์คํ์์ค ์ปค๋ฎค๋ํฐ์์ ๊ฐ๋ฐํ torch2trt ํจํค์ง๋ฅผ ์ฌ์ฉํด์ pytorch ๋ชจ๋ธ์ tensorRT ํธํ ํ์์ผ๋ก ๋ณํํด์ ๋ชจ๋ธ ์ธํผ๋ฐ์ค๋ฅผ ๊ฐ์ํ์์ผฐ๋ค. ํ์ง๋ง NIVIDA์ PyTorch๊ฐ ๊ณต์์ ์ผ๋ก ์ ๊ณตํ๋ Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด PyTorch ๋ชจ๋ธ์ ๋ณํํ ๋ ์ต์ ํ ์์ค์ ๋ ์ธ๋ฐํ๊ฒ ์ ์ดํ ์ ์๋ค.
Torch-TensorRT๋ ๋ค์๊ณผ ๊ฐ์ ์ฃผ์ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ค
- ๋ชจ๋ธ ์ต์ ํ: Torch-TensorRT๋ PyTorch ๋ชจ๋ธ์ TensorRT ํธํ ํ์์ผ๋ก ๋ณํํ์ฌ TensorRT์ ์ต์ ํ ๊ธฐ๋ฅ์ ํ์ฉํ ์ ์์. ์ด๋ก์จ ๋ชจ๋ธ์ ์ฐ์ฐ ๊ทธ๋ํ๊ฐ ์ต์ ํ๋๊ณ , GPU ๊ฐ์์ ์ํ ํน์ ์ฐ์ฐ ์ปค๋๋ก ๋์ฒด๋จ.
- TensorRT ๋ชจ๋ธ ์์ฑ: Torch-TensorRT๋ฅผ ์ฌ์ฉํ์ฌ ์ต์ ํ๋ ๋ชจ๋ธ์ ์์ฑํ ์ ์์. ์ด๋ ๊ฒ ์์ฑ๋ ๋ชจ๋ธ์ TensorRT ์์ง์ ์ฌ์ฉํ์ฌ ํจ์จ์ ์ผ๋ก ์ถ๋ก ์ ์ํํ ์ ์์.
- End-to-end ๊ฐ์ํ: Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด PyTorch ๋ชจ๋ธ์ ๊ฐ์ ธ์ TensorRT ๋ชจ๋ธ๋ก ๋ณํํ๊ณ ์ถ๋ก ์ ์ํํ๋ ์ ์ฒด ์๋ ํฌ ์๋ ํ์ดํ๋ผ์ธ์ ๊ตฌ์ถํ ์ ์์.
- ๋์ ๊ทธ๋ํ ์ง์: PyTorch์ ๋์ ๊ทธ๋ํ ํน์ฑ์ ์ ์งํ๋ฉด์ TensorRT๋ก์ ์ต์ ํ๊ฐ ๊ฐ๋ฅํ๋ฏ๋ก, ๋ค์ํ ๋ชจ๋ธ ์ํคํ ์ฒ์ ๋ํ ์ง์์ ์ ๊ณตํฉ๋๋ค.
Torch-TensorRT๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฅ ๋ฌ๋ ๋ชจ๋ธ์ NVIDIA GPU์์ ๋ ํจ์จ์ ์ผ๋ก ์คํํ๊ณ , ์ถ๋ก ์๋๋ฅผ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ผ๋ฉฐ, ์ด๋ ์ค์๊ฐ ์๊ตฌ ์ฌํญ์ด ์๋ ์์ฉ ํ๋ก๊ทธ๋จ ๋ฐ ์๋น์ค์ ํนํ ์ ์ฉํ๋ค.
๋ฒ์ ์ด ๋ง๋ pytorch๊ฐ ์ค์น๋์ด ์๋ค๋ฉด tensorrt, torch_tensorrt ๋ง ์ค์นํ๋ฉด ์์ ์ฝ๋๋ฅผ ์คํํด ๋ณผ ์ ์๋ค.
pip install tensorrt torch_tensorrt
ํ์ง๋ง ํ๋ฒ์ ์ ๋๋ก ๋์ํ์ง ์๋ ๊ฒฝ์ฐ๋ ๋ง๋ค... ใ TensorRT๋ ํ๊ฒฝ ๊ตฌ์ฑ์ด ์๊ฐ๋ณด๋ค ํ๋ค๊ธฐ ๋๋ฌธ์ ์ ๋์ํ๋ ํจํค์ง ๋ฒ์ ์กฐํฉ์ ์ ํ์ธํด์ ์ค์นํด์ผ ํ๋ค. Torch-TensorRT ๊นํ ๋ฆด๋ฆฌ์ค ํ์ด์ง์์ Torch-TensorRT ๋ฒ์ ๋ณ pytorch, cuda, tensorrt์์ ๋ฒ์ ํธํ์ฑ์ ํ์ธํ ์ ์๋ค.
→ https://github.com/pytorch/TensorRT/releases
๊ฐ์ธ์ ์ผ๋ก๋ TensorRT 8.6.1์ด ์ค์น๋ ์ธ์คํด์ค๋ฅผ ๋์์ pytorch 2.0 cuda 11.8 ์ ์ค์นํ ๋ค์ torch-tensorrt v1.4๋ฅผ ์ค์นํ๋ค. ๊ทธ๋ฌ๋๋... torchvision์ ๋ฒ์ ์ด ํธํ์ด ์๋๋ค๊ณ ํด์ torchvision๋ง ์ง์ ๋ค๊ฐ ๋ค์ ์ค์นํ๋๊น ์ ๋์ํ๋ค.
Python Torch-TensorRT ์ฌ์ฉ ์์
Torch-TensorRT๋ก ์ ๋ ฅ์ ์ปดํ์ผํ๋ ค๋ฉด torch.nn.ModuleTorch-TensorRT์ ๋ชจ๋๊ณผ ์ ๋ ฅ์ ์ ๊ณตํ๋ฉด ๋๋ค. ๊ทธ๋ฌ๋ฉด ๋ค๋ฅธ PyTorch ๋ชจ๋์ ์คํํ๊ฑฐ๋ ์ถ๊ฐํ ์ ์๋๋ก ์ต์ ํ๋ TorchScript ๋ชจ๋์ด ๋ฐํ๋๋ค.
์ ๋ ฅ์ ์ ๋ ฅ ํ ์์ ๋ชจ์, ๋ฐ์ดํฐ ์ ํ ๋ฐ ๋ฉ๋ชจ๋ฆฌ ํ์์ ์ ์ํ๋ torch_tensorrt.Input ํด๋์ค list์ด๋ค. ์ ๋ ฅ์ด ํํ์ด๋ Tensor ๋ชฉ๋ก๊ณผ ๊ฐ์ ๋ ๋ณต์กํ ๋ฐ์ดํฐ ์ ํ์ธ ๊ฒฝ์ฐ ์ธ์๋ฅผ ์ฌ์ฉํ์ฌ input_signature์ ๊ฐ์ ์ปฌ๋ ์ ๊ธฐ๋ฐ ์ ๋ ฅ(e.g. List[Tensor], Tuple[Tensor, Tensor])์ ์ง์ ํ ์ ์๋ค(๋ ๋ฒ์งธ ์์ ์ฝ๋).
์์ง์ด๋ ๋์ ์ฅ์น์ ์๋ ์ ๋ฐ๋์ ๊ฐ์ ์ค์ ์ ์ง์ ํ ์๋ ์๊ณ , ์ปดํ์ผ ํ์๋ ๋ค๋ฅธ ๋ชจ๋๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ๋ชจ๋์ ์ ์ฅํ์ฌ ๋ฐฐํฌ ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ก๋ํ ์ ์๋ค. TensorRT/TorchScript ๋ชจ๋์ ๋ก๋ํ๋ ค๋ฉด ๋จผ์ torch_tensorrt๋ฅผ ์ํฌํธ ํด์ผ ํ๋ค.
import torch_tensorrt
...
model = MyModel().eval() # torch module needs to be in eval (not training) mode
inputs = [
torch_tensorrt.Input(
min_shape=[1, 1, 16, 16],
opt_shape=[1, 1, 32, 32],
max_shape=[1, 1, 64, 64],
dtype=torch.half,
)
]
enabled_precisions = {torch.float, torch.half} # Run with fp16
trt_ts_module = torch_tensorrt.compile(
model, inputs=inputs, enabled_precisions=enabled_precisions
)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
- ์ค์นํ torch_tensorrt๋ฅผ ์ํฌํธ
- pytorch ๋ชจ๋ธ์ ๋ก๋ํ๊ณ torch_tensor.complie ์ํ
- ์ปดํ์ผ ์ input์ ์ ๋ ฅํ๋๋ฐ, ์ด ๋ ๋ชจ๋ธ ์ ๋ ฅ์ ์ต์ ํฌ๊ธฐ, ์ต์ ํ ํฌ๊ธฐ, ์ต๋ ํฌ๊ธฐ๋ฅผ ์ง์
- ์ดํ ๋ณํํ ๋ชจ๋ธ์ input์ ์ ๋ ฅํด ์ฃผ๊ธฐ๋ง ํ๋ฉด ๋จ
- torch.jit.save๋ก ๋ณํํ ๋ชจ๋ธ ์ ์ฅ ๊ฐ๋ฅ
*์ ๋ ฅ ํฌ๊ธฐ๋ฅผ ๋์ ์ผ๋ก ์กฐ์ ํ ์ ์๊ธด ํ์ง๋ง, ๋ชจ๋ธ ์ ํ๋์ ์๋๋ opt_shape (์ต์ ํ๋ ํฌ๊ธฐ)์์๋ง ๋ณด์ฅ ๋ฐ์ ์ ์์. ์ ๋ ฅ ํฌ๊ธฐ๊ฐ opt_shape๊ณผ ์ฐจ์ด๊ฐ ๋ง์ด ๋ ์๋ก pytorch ๋ชจ๋ธ ์ถ๋ ฅ๊ณผ tensorrt๋ก ๋ณํํ ๋ชจ๋ธ์ ์ถ๋ ฅ๊ฐ์ ์ฐจ์ด๊ฐ ๋ง์ด ๋จ.
# Sample using collection-based inputs via the input_signature argument
import torch_tensorrt
...
model = MyModel().eval()
# input_signature expects a tuple of individual input arguments to the module
# The module below, for example, would have a docstring of the form:
# def forward(self, input0: List[torch.Tensor], input1: Tuple[torch.Tensor, torch.Tensor])
input_signature = (
[torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)],
(torch_tensorrt.Input(shape=[64, 64], dtype=torch.half), torch_tensorrt.Input(shape=[64, 64], dtype=torch.half)),
)
enabled_precisions = {torch.float, torch.half}
trt_ts_module = torch_tensorrt.compile(
model, input_signature=input_signature, enabled_precisions=enabled_precisions
)
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")
- input_signature๋ฅผ ์ฌ์ฉํ๋ ์์
- ๋ค์ํ ํํ์ ์ ๋ ฅ์ ๋ฐ์ ์ ์์
# Deployment application
import torch
import torch_tensorrt
trt_ts_module = torch.jit.load("trt_ts_module.ts")
input_data = input_data.to("cuda").half()
result = trt_ts_module(input_data)
- tensorRT๋ก ๋ณํ & ์ ์ฅํ ๋ชจ๋ธ์ ๋ก๋ํด์ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ
# PyTorch ๋ชจ๋ธ ์ถ๋ก
torch_output = model(input_data)
# TensorRT ๋ชจ๋ธ ์ถ๋ก
trt_output = trt_ts_module(input_data)
# ๋ ์ถ๋ ฅ ๊ฐ์ ์ ์ฌ์ฑ ํ์ธ
if torch.allclose(torch_output, trt_output, atol=1e-3):
print("PyTorch์ TensorRT ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ์ ์ฌํฉ๋๋ค.")
else:
print("PyTorch์ TensorRT ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ๋ค๋ฆ
๋๋ค.")
- pytorch ๋ชจ๋ธ๊ณผ tensorrt ๋ชจ๋ธ์ ์ถ๋ ฅ ์ ์ฌ์ฑ ํ ์คํธ