Created September 25, 2020 04:32
Save giangnguyen2412/ede6391ccb7b0328ca6aa14e03a0a479 to your computer and use it in GitHub Desktop.
Convert a pretrained ResNet-50 model from PyTorch to TensorFlow using ONNX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Run this in the Python CLI (interactive session).
## Loads the pretrained torchvision ResNet-50 and exports it to 'resnet50.onnx',
## tracing it with one pre-processed test image as the example input.
#
# NOTE: the wildcard imports below can rebind names imported before them, so their
# relative order is kept exactly as originally written.
import torch
import torchvision
torch.set_num_threads(1)
from torchvision.models import *
from visualisation.core.utils import device, image_net_postprocessing
from torch import nn
from operator import itemgetter
from visualisation.core.utils import imshow
import glob
import matplotlib.pyplot as plt
import numpy as np
from utils import *
from PIL import Image

# Directory of test images. glob() returns paths in arbitrary, filesystem-dependent
# order, so sort them to make the index-based pick below reproducible.
test_image_paths = sorted(glob.glob('/home/dexter/Downloads/Exp1/Exp1-1 100/*.*'))
from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
from visualisation.core import *
from visualisation.core.utils import image_net_preprocessing

size = 224
# Pre-process the image and convert it into a tensor: resize the shorter side to
# `size`, center-crop to size x size, scale to [0, 1], then normalize each channel
# with the (ImageNet) mean/std constants below.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size),
    torchvision.transforms.CenterCrop(size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

# Which test image to use as the example input for tracing.
export_image_index = 8
if len(test_image_paths) <= export_image_index:
    raise FileNotFoundError(
        f'need at least {export_image_index + 1} test images, '
        f'found {len(test_image_paths)}')

trained_model = resnet50(pretrained=True)
# Export in inference mode explicitly (layers such as BatchNorm then use their
# running statistics) rather than relying on the exporter's default `training` mode.
trained_model.eval()
with Image.open(test_image_paths[export_image_index]) as img:
    # convert('RGB') so grayscale/RGBA/palette images still yield the 3 channels the
    # Normalize step expects; the `with` block closes the underlying file.
    x = transform(img.convert('RGB')).unsqueeze(0)  # add a batch dimension
torch.onnx.export(trained_model, x, 'resnet50.onnx')
## Run this separately from the CLI session above (e.g. as a script).
## It re-imports everything it needs and depends only on the 'resnet50.onnx' file
## written by the export step. It converts that ONNX model to TensorFlow and checks
## that the original PyTorch model and the converted model predict the same label
## for one test image. It then saves the converted model.
import onnx
from onnx_tf.backend import prepare
onnx_model = onnx.load('resnet50.onnx')
tf_rep = prepare(onnx_model)
# NOTE: the wildcard imports below can rebind names imported before them, so their
# relative order is kept exactly as originally written.
import torch
import torchvision
torch.set_num_threads(1)
from torchvision.models import *
from visualisation.core.utils import device, image_net_postprocessing
from torch import nn
from operator import itemgetter
from visualisation.core.utils import imshow
import glob
import matplotlib.pyplot as plt
import numpy as np
from utils import *
from PIL import Image

# Your directory of test images. The original model and the converted model are
# expected to give the same label for the same image. Sorted because glob() order
# is arbitrary and filesystem-dependent.
test_image_paths = sorted(glob.glob('/home/dexter/Downloads/Exp1/Exp1-1 100/*.*'))
from torchvision.transforms import ToTensor, Resize, Compose, ToPILImage
from visualisation.core import *
from visualisation.core.utils import image_net_preprocessing

size = 224
# Pre-process the image and convert it into a tensor. This must be the same
# pipeline that was used when exporting the model.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size),
    torchvision.transforms.CenterCrop(size),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

# Which test image to compare the two models on.
test_image_index = 10
if len(test_image_paths) <= test_image_index:
    raise FileNotFoundError(
        f'need at least {test_image_index + 1} test images, '
        f'found {len(test_image_paths)}')

model = resnet50(pretrained=True).to(device)
model.eval()
# Your test image; convert('RGB') guarantees the 3 channels Normalize expects.
with Image.open(test_image_paths[test_image_index]) as img:
    x = transform(img.convert('RGB')).unsqueeze(0).to(device)

# Inference only -- no autograd graph is needed.
with torch.no_grad():
    out = model(x)
p = torch.nn.functional.softmax(out, dim=1)
score, index = torch.topk(p, 1)
input_category_id = index[0][0].item()
predicted_confidence = score[0][0].item()

## Now run the converted model.
# Hand the TensorFlow backend a NumPy array (on the CPU) rather than a torch tensor.
output = tf_rep.run(x.cpu().numpy())
# run() returns one array per graph output. The PyTorch model returns a single
# tensor, so the scores are output[0]; with a batch of one, the flat argmax is the
# class index.
tf_category_id = int(np.argmax(output[0]))
print(f'PyTorch: class {input_category_id} (p={predicted_confidence:.4f}); '
      f'TensorFlow: class {tf_category_id}')
if tf_category_id != input_category_id:
    print('WARNING: the converted model predicts a different label')

# NOTE(review): newer onnx-tf releases write a TensorFlow SavedModel directory at
# this path instead of a single .pb file -- confirm against the installed version.
tf_rep.export_graph('resnet50.pb')  # Save the model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.