Skip to content

Instantly share code, notes, and snippets.

@sandeepkumar-skb
Last active August 31, 2020 23:36
Show Gist options
  • Save sandeepkumar-skb/0eabb4adf5c4b115c94cfa5dc24ba16b to your computer and use it in GitHub Desktop.
Save sandeepkumar-skb/0eabb4adf5c4b115c94cfa5dc24ba16b to your computer and use it in GitHub Desktop.
Running inference on a Torchvision ResNet in TensorRT using Torch2TRT (the script is named Torch2TRT_resnet50.py, but the code loads ResNet18)
import torch
from torch2trt import torch2trt
import torchvision.models as models
import tensorrt as trt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
def get_trt_engine(model, inputs, max_batch_size=1, fp16_mode=True, int8_mode=False,
                   int8_calib_dataset=None, max_workspace_size=1 << 30):
    """Convert a PyTorch module to a TensorRT-backed module using torch2trt.

    Args:
        model: The PyTorch module to convert. The callers in this script pass a
            model that is already in eval mode on CUDA; other callers should do
            the same (assumption, not checked here).
        inputs: Example input used to trace the model. Either a single tensor,
            or a list/tuple of tensors with one entry per model input.
        max_batch_size: Largest batch size the built engine must accept.
        fp16_mode: Allow TensorRT to use FP16 kernels.
        int8_mode: Allow TensorRT to use INT8 kernels (requires calibration).
        int8_calib_dataset: Calibration data forwarded unchanged to torch2trt.
            Per the torch2trt INT8 docs, each item should be a list of input
            tensors (batch dimension included) -- verify against the
            installed torch2trt version.
        max_workspace_size: Scratch workspace TensorRT may use while building
            the engine. The default ``1 << 30`` is 1 GiB, the value previously
            hard-coded here.

    Returns:
        The module returned by ``torch2trt``. It is called like the original
        model (``model_trt(x)``).
    """
    # torch2trt takes a list of example inputs, one per forward() argument.
    # Accept an existing sequence as-is instead of nesting it a second time.
    example_inputs = list(inputs) if isinstance(inputs, (list, tuple)) else [inputs]
    return torch2trt(
        model,
        example_inputs,
        max_batch_size=max_batch_size,
        fp16_mode=fp16_mode,
        int8_mode=int8_mode,
        int8_calib_dataset=int8_calib_dataset,
        log_level=trt.Logger.INFO,
        max_workspace_size=max_workspace_size,
    )
def exp_dummy_inputs(model, num_calib_samples=100):
    """Build an FP16/INT8 TensorRT engine calibrated on random data and
    print its maximum absolute deviation from ``model`` on an all-ones input.

    Args:
        model: The PyTorch module to convert (expected in eval mode on CUDA).
        num_calib_samples: Number of random calibration samples (default 100,
            the count previously hard-coded here).
    """
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    # torch2trt's INT8 calibration docs describe each calibration item as a
    # list of input tensors (one per model input, batch dim included). Wrap
    # each sample in a list so a bare tensor is not iterated along its batch
    # dimension. Random normal data only exercises the pipeline; the
    # resulting INT8 scales are not representative of real images.
    calibration_dataset = [
        [torch.randn([1, 3, 224, 224], device='cuda', dtype=torch.float)]
        for _ in range(num_calib_samples)
    ]
    model_trt = get_trt_engine(model, x, fp16_mode=True, int8_mode=True,
                               int8_calib_dataset=calibration_dataset)
    # Pure inference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        y = model(x)
        y_trt = model_trt(x)
    print(torch.max(torch.abs(y - y_trt)))
def exp_tiny_imagenet_inputs(model, num_calib_samples=100):
    """Build an FP16/INT8 TensorRT engine calibrated on Tiny ImageNet images
    and print its maximum absolute deviation from ``model`` on an all-ones
    input.

    Args:
        model: The PyTorch module to convert (expected in eval mode on CUDA).
        num_calib_samples: Number of individual validation images used for
            INT8 calibration.
    """
    import itertools

    val_loader = get_dataset()
    # The loader yields (images, labels) batches of many images on the CPU.
    # The engine is built for a single [1, 3, 224, 224] CUDA input
    # (max_batch_size=1), so flatten the batches into individual images. Each
    # calibration item is then a one-element list holding a batch-of-one
    # CUDA tensor. Previously each item was an entire multi-image CPU batch.
    images = (image for batch, _ in val_loader for image in batch)
    calibration_dataset = [
        [image.unsqueeze(0).to('cuda')]
        for image in itertools.islice(images, num_calib_samples)
    ]
    del images  # release the DataLoader iterator (and its workers) early
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    model_trt = get_trt_engine(model, x, fp16_mode=True, int8_mode=True,
                               int8_calib_dataset=calibration_dataset)
    # Pure inference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        y_trt = model_trt(x)
        y = model(x)
    print(torch.max(torch.abs(y - y_trt)))
def get_dataset(val_dir="tiny-imagenet-200/val", batch_size=256, num_workers=8, shuffle=True):
    """Return a DataLoader over the Tiny ImageNet validation images.

    Images use deterministic evaluation preprocessing: resize, center crop to
    224x224, tensor conversion, and normalization with the ImageNet mean/std.
    Calibration data therefore matches what the pretrained model sees at
    inference time. The previous RandomResizedCrop/RandomHorizontalFlip are
    training augmentations and made the calibration set non-representative
    and non-reproducible.

    Args:
        val_dir: Directory handed to ``datasets.ImageFolder``. The path is
            relative to the current working directory. NOTE(review): Tiny
            ImageNet's val split appears to keep all images in a single
            subfolder, so the labels are presumably meaningless; the callers
            in this script use only the images.
        batch_size: Images per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Shuffle the order, so the first batches form a random subset.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = datasets.ImageFolder(val_dir, val_transform)
    return torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=False)
if __name__ == "__main__":
    # Script entry point: load a pretrained torchvision ResNet18, switch it to
    # inference mode on the GPU (eval() and cuda() both return the module
    # itself), then run both accuracy-comparison experiments.
    print("Importing RN18 from torchvision")
    model = models.resnet18(pretrained=True).eval().cuda()
    exp_dummy_inputs(model)
    exp_tiny_imagenet_inputs(model)
@sandeepkumar-skb
Copy link
Author

sandeepkumar-skb commented Aug 28, 2020

Steps to run:

  1. Run NGC PyTorch Docker
    When I wrote this script I used the 20.03 container: docker run -it --gpus=all --ipc=host -v=<path>:/host nvcr.io/nvidia/pytorch:20.03-py3
  2. Clone and install the Torch2TRT package
    git clone https://github.com/NVIDIA-AI-IOT/torch2trt
    cd torch2trt
    python setup.py install
    
  3. Run cd .. and copy the script into that directory (the parent of the torch2trt checkout).
  4. To run only with dummy data, comment out the exp_tiny_imagenet_inputs(model) call so that only exp_dummy_inputs(model) runs.
  5. To run with Tiny ImageNet, download and extract the dataset:
    5a. wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    5b. unzip tiny-imagenet-200.zip
    5c. If you extract it anywhere other than tiny-imagenet-200/ in the directory you run the script from, update val_dir in get_dataset().
  6. Run the script
    python Torch2TRT_resnet50.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment