@aoirint
Created March 20, 2020 03:37
torch jit test
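import torch
import torchvision.models as M

# Load an ImageNet-pretrained ResNet-18.
# Note: newer torchvision releases prefer the `weights=` argument over `pretrained=True`.
model = M.resnet18(pretrained=True)
model.eval()  # switch to inference mode before tracing (fixes BatchNorm behavior)
# model = M.resnet50(pretrained=True)

# Example input; tracing records the operations executed for this tensor.
dummy = torch.randn((1, 3, 224, 224))
trmodel = torch.jit.trace(model, dummy)

# Save the traced module to disk.
file = 'resnet18.pt'
trmodel.save(file)
# trmodel.save('resnet50.pt')

# Load it back to confirm the saved file is readable.
loaded = torch.jit.load(file)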
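The snippet above saves and reloads the traced model but does not check that the reloaded module behaves like the original. The following is a minimal sketch of such a check, not part of the original gist. To keep it runnable without downloading pretrained weights, it uses a small hypothetical `TinyNet` in place of `resnet18`; the helper name `trace_roundtrip` and the temporary file path are likewise assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18 so the sketch needs no weight download."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


def trace_roundtrip(model, example):
    """Trace `model`, save it, load it back, and return (eager_out, loaded_out)."""
    model.eval()
    traced = torch.jit.trace(model, example)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.pt")
        traced.save(path)
        loaded = torch.jit.load(path)
    with torch.no_grad():
        return model(example), loaded(example)


torch.manual_seed(0)
dummy = torch.randn(1, 3, 32, 32)
expected, actual = trace_roundtrip(TinyNet(), dummy)

# Compare the eager model's output with the reloaded traced module's output.
print(torch.allclose(expected, actual, atol=1e-6))
```

The same helper should work with `M.resnet18(pretrained=True)` and a `(1, 3, 224, 224)` input, assuming the weights can be downloaded.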