Skip to content

Instantly share code, notes, and snippets.

@sandeepkumar-skb
Created August 25, 2020 05:12
Show Gist options
  • Save sandeepkumar-skb/040c115f4ffcfde118cd456eb574b0de to your computer and use it in GitHub Desktop.
Save sandeepkumar-skb/040c115f4ffcfde118cd456eb574b0de to your computer and use it in GitHub Desktop.
Sample script to run INT8 in TensorRT with calibration data on Tiny-Imagenet
import torch
from torch2trt import torch2trt
import torchvision.models as models
import tensorrt as trt
import torchvision.transforms as transforms
import torchvision.datasets as datasets
def get_trt_engine(model, inputs, max_batch_size=1, fp16_mode=True, int8_mode=False, int8_calib_dataset=None):
    """Convert ``model`` into a TensorRT-backed module via ``torch2trt``.

    Args:
        model: Module handed to ``torch2trt`` for conversion.
        inputs: A single example input; it is wrapped in a one-element list
            before being passed to ``torch2trt``.
        max_batch_size: Forwarded unchanged to ``torch2trt``.
        fp16_mode: Forwarded unchanged to ``torch2trt``.
        int8_mode: Forwarded unchanged to ``torch2trt``.
        int8_calib_dataset: Forwarded unchanged to ``torch2trt``.

    Returns:
        The object returned by ``torch2trt`` for the converted model.
    """
    build_options = dict(
        max_batch_size=max_batch_size,
        fp16_mode=fp16_mode,
        int8_mode=int8_mode,
        int8_calib_dataset=int8_calib_dataset,
        log_level=trt.Logger.INFO,
        # 1 << 30 == 1073741824; presumably a byte count (1 GiB) -- confirm
        # against the torch2trt documentation.
        max_workspace_size=1 << 30,
    )
    return torch2trt(model, [inputs], **build_options)
def exp_dummy_inputs(model):
    """Build an FP16+INT8 engine calibrated on random tensors and print its error.

    An all-ones CUDA tensor of shape ``[1, 3, 224, 224]`` is used both as the
    example input for the conversion and as the probe for the comparison.
    The calibration set is 100 standard-normal CUDA tensors of the same
    shape. The largest absolute element-wise difference between the outputs
    of ``model`` and of the converted module is printed.

    Args:
        model: Module to convert; it is called directly with a CUDA tensor,
            so it must already accept CUDA inputs.

    Returns:
        None. The result is only printed.
    """
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    calibration_dataset = [
        torch.randn([1, 3, 224, 224], device='cuda', dtype=torch.float)
        for _ in range(100)
    ]
    model_trt = get_trt_engine(
        model,
        x,
        fp16_mode=True,
        int8_mode=True,
        int8_calib_dataset=calibration_dataset,
    )
    # This is pure inference: disabling autograd avoids building a graph for
    # the reference forward pass, and the printed tensor no longer carries a
    # grad_fn.
    with torch.no_grad():
        y = model(x)
        y_trt = model_trt(x)
        print(torch.max(torch.abs(y - y_trt)))
def exp_tiny_imagenet_inputs(model):
    """Build an FP16+INT8 engine calibrated on loader batches and print its error.

    Up to 10 batches are drawn from the loader returned by ``get_dataset()``;
    element ``[0]`` of each batch is kept as one calibration item. An
    all-ones CUDA tensor of shape ``[1, 3, 224, 224]`` serves as the example
    input for the conversion and as the probe for the comparison. The largest
    absolute element-wise difference between the outputs of ``model`` and of
    the converted module is printed.

    Args:
        model: Module to convert; it is called directly with a CUDA tensor,
            so it must already accept CUDA inputs.

    Returns:
        None. The result is only printed.

    Raises:
        RuntimeError: If the loader yields no batches at all.
    """
    from itertools import islice

    val_loader = get_dataset()
    # islice stops cleanly when the loader has fewer than 10 batches instead
    # of leaking a bare StopIteration out of a manual next() call.
    # batch[0] is presumably the image tensor of an (images, labels) pair.
    calibration_dataset = [batch[0] for batch in islice(val_loader, 10)]
    if not calibration_dataset:
        raise RuntimeError("validation loader produced no calibration batches")

    # NOTE(review): every calibration item is a whole loader batch, while the
    # example input below has batch size 1 -- TODO confirm that torch2trt's
    # calibrator consumes items of this shape as intended.
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    model_trt = get_trt_engine(
        model,
        x,
        fp16_mode=True,
        int8_mode=True,
        int8_calib_dataset=calibration_dataset,
    )
    # This is pure inference: disabling autograd avoids building a graph for
    # the reference forward pass, and the printed tensor no longer carries a
    # grad_fn.
    with torch.no_grad():
        y_trt = model_trt(x)
        y = model(x)
        print(torch.max(torch.abs(y - y_trt)))
def get_dataset(val_dir="tiny-imagenet-200/val", batch_size=256, num_workers=8):
    """Create a shuffled ``DataLoader`` over an ``ImageFolder`` directory.

    Each image goes through ``Resize(256)``, ``RandomResizedCrop(224)``,
    ``RandomHorizontalFlip()``, ``ToTensor()`` and a ``Normalize`` with the
    mean/std constants given below.

    Args:
        val_dir: Directory passed to ``datasets.ImageFolder``. Defaults to
            the path the script originally hard-coded.
        batch_size: Batch size passed to the ``DataLoader``.
        num_workers: Worker count passed to the ``DataLoader``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and
        ``pin_memory=False``.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    # NOTE(review): the random crop/flip together with shuffle=True mean two
    # runs draw different samples; if reproducible calibration is wanted, a
    # deterministic pipeline may be preferable -- confirm intent.
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = datasets.ImageFolder(val_dir, preprocess)
    return torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
    )
if __name__ == "__main__":
    # Script entry point: load a pretrained ResNet-18, switch it to eval mode
    # on the GPU, then run the Tiny-ImageNet calibration experiment.
    print("Importing RN18 from torchvision")
    model = models.resnet18(pretrained=True).eval().cuda()
    # To calibrate on random tensors instead, call exp_dummy_inputs(model).
    exp_tiny_imagenet_inputs(model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment