Created
August 25, 2020 05:12
-
-
Save sandeepkumar-skb/040c115f4ffcfde118cd456eb574b0de to your computer and use it in GitHub Desktop.
Sample script to run INT8 in TensorRT with calibration data on Tiny-Imagenet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch2trt import torch2trt | |
import torchvision.models as models | |
import tensorrt as trt | |
import torchvision.transforms as transforms | |
import torchvision.datasets as datasets | |
def get_trt_engine(model, inputs, max_batch_size=1, fp16_mode=True, int8_mode=False, int8_calib_dataset=None):
    """Convert a PyTorch model into a TensorRT-backed module via torch2trt.

    Args:
        model: The PyTorch module to convert.
        inputs: A single example input; it is wrapped in a list before being
            handed to ``torch2trt``.
        max_batch_size: Forwarded to ``torch2trt`` as ``max_batch_size``.
        fp16_mode: Forwarded to ``torch2trt`` as ``fp16_mode``.
        int8_mode: Forwarded to ``torch2trt`` as ``int8_mode``.
        int8_calib_dataset: Forwarded to ``torch2trt`` as
            ``int8_calib_dataset``.

    Returns:
        Whatever ``torch2trt`` returns for the given model and options.
    """
    workspace_bytes = 1 << 30  # 2**30 bytes (1 GiB)
    conversion_options = dict(
        max_batch_size=max_batch_size,
        fp16_mode=fp16_mode,
        int8_mode=int8_mode,
        int8_calib_dataset=int8_calib_dataset,
        log_level=trt.Logger.INFO,
        max_workspace_size=workspace_bytes,
    )
    return torch2trt(model, [inputs], **conversion_options)
def exp_dummy_inputs(model):
    """Build an INT8 engine calibrated on random tensors and print its error.

    One hundred random ``[1, 3, 224, 224]`` CUDA tensors serve as the
    calibration dataset. The converted module and the original model are then
    run on the same all-ones input, and the maximum absolute difference
    between their outputs is printed.

    Args:
        model: The PyTorch model to convert and compare against.
    """
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    calibration_dataset = [
        torch.randn([1, 3, 224, 224], device='cuda', dtype=torch.float)
        for _ in range(100)
    ]
    model_trt = get_trt_engine(
        model,
        x,
        fp16_mode=True,
        int8_mode=True,
        int8_calib_dataset=calibration_dataset,
    )
    # Pure inference comparison: skip autograd bookkeeping so no graph is
    # built for the reference forward pass.
    with torch.no_grad():
        y = model(x)
        y_trt = model_trt(x)
    print(torch.max(torch.abs(y - y_trt)))
def exp_tiny_imagenet_inputs(model):
    """Build an INT8 engine calibrated on dataset batches and print its error.

    Up to ten image batches are drawn from the loader returned by
    ``get_dataset()`` and used as the calibration dataset. The converted
    module and the original model are then run on the same all-ones input,
    and the maximum absolute difference between their outputs is printed.

    Args:
        model: The PyTorch model to convert and compare against.

    Raises:
        ValueError: If the validation loader yields no batches.
    """
    num_calib_batches = 10
    val_loader = get_dataset()
    # zip() against range() stops after num_calib_batches and, unlike manual
    # next() calls, does not raise StopIteration when the loader is shorter.
    # The loader iterator is also released as soon as the list is built.
    calibration_dataset = [
        batch[0] for _, batch in zip(range(num_calib_batches), val_loader)
    ]
    if not calibration_dataset:
        raise ValueError("validation loader produced no batches for INT8 calibration")
    # NOTE(review): each entry is a whole image batch as produced by the
    # loader; confirm how the installed torch2trt version indexes
    # int8_calib_dataset entries.
    x = torch.ones([1, 3, 224, 224], device='cuda', dtype=torch.float)
    model_trt = get_trt_engine(
        model,
        x,
        fp16_mode=True,
        int8_mode=True,
        int8_calib_dataset=calibration_dataset,
    )
    # Pure inference comparison: skip autograd bookkeeping so no graph is
    # built for the reference forward pass.
    with torch.no_grad():
        y_trt = model_trt(x)
        y = model(x)
    print(torch.max(torch.abs(y - y_trt)))
def get_dataset(val_dir="tiny-imagenet-200/val", batch_size=256, num_workers=8):
    """Create a DataLoader over an image folder for calibration/validation.

    Images are resized to 256, center-cropped to 224, converted to tensors
    and normalized. The preprocessing is deterministic, so calibration data
    is processed the same way on every run rather than with training-style
    random augmentation.

    Args:
        val_dir: Root directory passed to ``datasets.ImageFolder``.
        batch_size: Number of images per batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding shuffled batches from the
        image folder.
    """
    # Channel statistics conventionally used with torchvision's
    # ImageNet-pretrained models.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = datasets.ImageFolder(val_dir, preprocess)
    # Shuffling is kept so that the first few batches drawn by a caller are
    # not taken in the folder's on-disk order.
    return torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
    )
if __name__ == "__main__":
    print("Importing RN18 from torchvision")
    # Load ResNet-18 with pretrained weights, switch it to eval mode, then
    # move it to the GPU.
    model = models.resnet18(pretrained=True)
    model = model.eval()
    model = model.cuda()
    # Alternative experiment using random calibration tensors:
    # exp_dummy_inputs(model)
    exp_tiny_imagenet_inputs(model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment