  • Stars: 254
  • Rank 160,264 (Top 4 %)
  • Language: Python
  • License: MIT License
  • Created over 4 years ago
  • Updated 10 months ago

Repository Details

A PyTorch to TensorRT converter with dynamic shape support.

torch2trt dynamic

This is a fork of torch2trt that adds dynamic input support.

Note that not all layers support dynamic input; torch.split(), for example, does not.

Usage

The examples below cover conversion, execution, and saving and loading.

Convert

from torch2trt_dynamic import torch2trt_dynamic
import torch
from torch import nn
from torchvision.models.resnet import resnet50

# create some regular pytorch model...
model = resnet50().cuda().eval()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT, feeding sample data as input;
# opt_shape_param holds one [min, opt, max] group of shapes per model input
opt_shape_param = [
    [
        [1, 3, 128, 128],   # min
        [1, 3, 256, 256],   # opt
        [1, 3, 512, 512]    # max
    ]
]
model_trt = torch2trt_dynamic(model, [x], fp16_mode=False, opt_shape_param=opt_shape_param)
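
The nesting of opt_shape_param is easy to get wrong: the outer list has one entry per model input, and each entry holds three shapes in the order min, opt, max. The sketch below is a pre-flight check under that assumption. validate_opt_shape_param is a hypothetical helper, not part of torch2trt_dynamic; it only verifies the structure and that min <= opt <= max holds in every dimension, which a TensorRT optimization profile requires.

```python
def validate_opt_shape_param(opt_shape_param, num_inputs):
    """Hypothetical helper: sanity-check the [min, opt, max] shape ranges."""
    if len(opt_shape_param) != num_inputs:
        raise ValueError(
            f"expected {num_inputs} shape ranges, got {len(opt_shape_param)}")
    for index, shapes in enumerate(opt_shape_param):
        if len(shapes) != 3:
            raise ValueError(f"input {index}: need [min, opt, max] shapes")
        min_shape, opt_shape, max_shape = shapes
        if not (len(min_shape) == len(opt_shape) == len(max_shape)):
            raise ValueError(f"input {index}: shapes differ in rank")
        # every dimension must satisfy min <= opt <= max
        for lo, mid, hi in zip(min_shape, opt_shape, max_shape):
            if not (lo <= mid <= hi):
                raise ValueError(
                    f"input {index}: need min <= opt <= max, got {lo}, {mid}, {hi}")


opt_shape_param = [
    [
        [1, 3, 128, 128],   # min
        [1, 3, 256, 256],   # opt
        [1, 3, 512, 512],   # max
    ]
]
validate_opt_shape_param(opt_shape_param, num_inputs=1)  # passes silently
```

Running such a check before conversion can surface a swapped min/max pair as a plain Python error, rather than leaving it to be found during the engine build.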

Execute

We can execute the returned TRTModule just like the original PyTorch model. Input shapes must stay within the min/max range given at conversion.

x = torch.rand(1,3,256,256).cuda()
with torch.no_grad():
    y = model(x)
    y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
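
Because the engine is built for a range of shapes, it can help to confirm that an input falls inside that range before executing it. A minimal sketch, assuming the [min, opt, max] layout used during conversion; shape_in_range is a hypothetical helper and not a library function.

```python
def shape_in_range(shape, shape_range):
    """Return True if `shape` lies within the [min, opt, max] range."""
    min_shape, _opt_shape, max_shape = shape_range
    if len(shape) != len(min_shape):
        return False
    return all(lo <= dim <= hi
               for lo, dim, hi in zip(min_shape, shape, max_shape))


shape_range = [[1, 3, 128, 128], [1, 3, 256, 256], [1, 3, 512, 512]]
print(shape_in_range((1, 3, 256, 256), shape_range))  # True
print(shape_in_range((1, 3, 640, 640), shape_range))  # False
print(shape_in_range((2, 3, 256, 256), shape_range))  # False: batch size is fixed at 1
```

With a real tensor, tuple(x.shape) gives the shape to test.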

Save and load

We can save the converted model as a state_dict:

torch.save(model_trt.state_dict(), 'resnet50_trt.pth')

We can then load the saved state into a new TRTModule:

from torch2trt_dynamic import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('resnet50_trt.pth'))

Setup

To install without compiling plugins, run the following commands:

git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
cd torch2trt_dynamic
python setup.py develop
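
After installation, a quick check can confirm that the package and the modules imported in the examples are visible to the interpreter. This is only a sketch: is_importable is a hypothetical helper, and the list of names simply mirrors the imports used in this README.

```python
import importlib.util


def is_importable(module_name):
    """Return True if the top-level module can be found on the Python path."""
    return importlib.util.find_spec(module_name) is not None


for name in ("torch", "tensorrt", "torch2trt_dynamic"):
    status = "found" if is_importable(name) else "MISSING"
    print(f"{name}: {status}")
```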

Set up plugins (optional)

Some layers, such as GroupNorm (GN), need C++ plugins. Install the plugin project below:

amirstan_plugin

DO NOT FORGET to export the environment variable AMIRSTAN_LIBRARY_PATH.
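
A missing variable is easy to detect early. The sketch below assumes AMIRSTAN_LIBRARY_PATH holds a directory path; it only checks that the variable is set and that the directory exists, not which files the amirstan_plugin build places there. The function name is hypothetical.

```python
import os


def amirstan_library_dir(environ=os.environ):
    """Return the directory named by AMIRSTAN_LIBRARY_PATH, or raise."""
    path = environ.get("AMIRSTAN_LIBRARY_PATH")
    if not path:
        raise RuntimeError("AMIRSTAN_LIBRARY_PATH is not set")
    # assumption: the variable points at a directory, not a single file
    if not os.path.isdir(path):
        raise RuntimeError(f"AMIRSTAN_LIBRARY_PATH is not a directory: {path}")
    return path


try:
    print("plugin libraries expected in:", amirstan_library_dir())
except RuntimeError as error:
    print("plugin setup incomplete:", error)
```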

How to add (or override) a converter

Here we show how to add a converter for the ReLU module using the TensorRT Python API.

import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter

@tensorrt_converter('torch.nn.ReLU.forward')
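def convert_ReLU(ctx):
    # method_args[0] is the module itself (self); the input tensor comes next
    input = ctx.method_args[1]
    output = ctx.method_return
    # add a TensorRT ReLU activation layer fed by the input's TensorRT tensor
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
    # attach the layer's output so that later converters can consume it
    output._trt = layer.get_output(0)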

The converter takes one argument, a ConversionContext, which contains the following:

  • ctx.network - The TensorRT network that is being constructed.

  • ctx.method_args - Positional arguments that were passed to the specified PyTorch function. The _trt attribute is set for relevant input tensors.

  • ctx.method_kwargs - Keyword arguments that were passed to the specified PyTorch function.

  • ctx.method_return - The value returned by the specified PyTorch function. The converter must set the _trt attribute where relevant.
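
To make the "add (or override)" behaviour concrete, here is a toy, dependency-free model of a decorator-based converter registry. It is not torch2trt_dynamic's implementation: CONVERTERS, register_converter, and FakeContext are hypothetical stand-ins, and the "converters" return strings instead of building TensorRT layers. It illustrates one plausible reading of "override", namely that registering a second converter under the same method name replaces the first.

```python
CONVERTERS = {}  # hypothetical registry: method name -> converter function


def register_converter(method_name):
    """Toy stand-in for a registration decorator such as tensorrt_converter."""
    def decorator(converter):
        CONVERTERS[method_name] = converter  # a later registration overrides
        return converter
    return decorator


class FakeContext:
    """Toy stand-in for a ConversionContext; holds only the call arguments."""
    def __init__(self, method_args, method_kwargs=None):
        self.method_args = method_args
        self.method_kwargs = method_kwargs or {}


@register_converter('torch.nn.ReLU.forward')
def convert_relu(ctx):
    return f"relu({ctx.method_args[1]})"


@register_converter('torch.nn.ReLU.forward')
def convert_relu_override(ctx):
    return f"custom_relu({ctx.method_args[1]})"


# method_args[0] plays the role of `self`, method_args[1] the input
ctx = FakeContext(method_args=("self", "x"))
print(CONVERTERS['torch.nn.ReLU.forward'](ctx))  # custom_relu(x)
```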

Please see this folder for more examples.