Skip to content

Instantly share code, notes, and snippets.

@vicradon
Last active December 28, 2024 05:49
Show Gist options
  • Save vicradon/eec2bb235ec4043ea60c08ffdac44e6e to your computer and use it in GitHub Desktop.
Save vicradon/eec2bb235ec4043ea60c08ffdac44e6e to your computer and use it in GitHub Desktop.
PyTorch model inference using eval mode
import torch
from torchvision import models
# Dummy input traced through the model during export:
# a batch of one image, 3 color channels, height 224, width 224.
dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)

# Rebuild the architecture the checkpoint expects: a ResNet-50 whose final
# fully connected layer is replaced by one with 102 outputs.
resnet = models.resnet50()
resnet.fc = torch.nn.Linear(in_features=2048, out_features=102)
resnet.load_state_dict(
    torch.load("model.pth", weights_only=True, map_location=torch.device('cpu'))
)

# Switch to inference mode explicitly before exporting so mode-dependent
# layers (e.g. batch normalization) are captured with their inference
# behavior, regardless of which exporter defaults are in effect.
resnet.eval()

# Writes the exported graph to "model.onnx" in the current directory.
torch.onnx.export(resnet, dummy_input, "model.onnx")
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import argparse
def load_and_prepare_image(image_path):
    """Read an image from disk and return it as a normalized one-image batch.

    The image is opened, converted to RGB, resized to 224x224, turned into a
    tensor, and normalized per channel with the mean/std values below. A
    leading batch dimension is then added with ``unsqueeze(0)``.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The preprocessed image tensor with an extra leading dimension.
    """
    resize = transforms.Resize((224, 224))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose([resize, to_tensor, normalize])

    rgb_image = Image.open(image_path).convert('RGB')
    return pipeline(rgb_image).unsqueeze(0)
def main():
    """Classify one image with the ResNet-50 checkpoint and print its class index.

    Reads the image path from the command line, loads the weights from
    ``model.pth`` onto the CPU, runs a single forward pass in eval mode with
    gradients disabled, and prints the index of the highest-probability class.
    """
    # Parse the command line first so that ``--help`` and usage errors are
    # reported immediately, without building the model or requiring
    # "model.pth" to exist.
    parser = argparse.ArgumentParser(description="Classify an image using a ResNet model.")
    parser.add_argument("image_path", type=str, help="Path to the image to classify")
    args = parser.parse_args()

    # ResNet-50 with its final layer replaced by a 102-output linear layer,
    # matching the shape of the saved state dict.
    resnet = models.resnet50()
    resnet.fc = torch.nn.Linear(in_features=2048, out_features=102)
    resnet.load_state_dict(
        torch.load("model.pth", weights_only=True, map_location=torch.device('cpu'))
    )
    resnet.eval()

    tensor_image = load_and_prepare_image(args.image_path)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = resnet(tensor_image)

    # Apply softmax for a probability distribution over the classes, then
    # drop the batch dimension.
    class_distribution = torch.nn.functional.softmax(output, dim=1).squeeze()
    predicted_class = torch.argmax(class_distribution).item()
    print(f"Predicted class: {predicted_class}")


if __name__ == "__main__":
    main()
import onnxruntime
import numpy as np
from PIL import Image
import argparse
from torchvision import transforms
def load_and_prepare_image(image_path):
    """Read an image from disk and return it as a NumPy array with a batch axis.

    The image is opened, converted to RGB, resized to 224x224, turned into a
    tensor, and normalized per channel with the mean/std values below. The
    result is converted to a NumPy array and given a new leading axis.

    Args:
        image_path: Path of the image file to load.

    Returns:
        The preprocessed image as a NumPy array with an extra leading axis.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    pipeline = transforms.Compose(steps)

    rgb_image = Image.open(image_path).convert('RGB')
    image_array = pipeline(rgb_image).numpy()
    # Insert the batch axis in front of the existing axes.
    return np.expand_dims(image_array, axis=0)
def main():
    """Classify one image with the exported ONNX model and print the result.

    Reads the image path from the command line, runs ``model.onnx`` on the
    CPU execution provider, converts the output scores to probabilities with
    a softmax, and prints the top class index and its probability.
    """
    arg_parser = argparse.ArgumentParser(description="Classify an image using ONNX model.")
    arg_parser.add_argument("image_path", type=str, help="Path to the image to classify")
    cli_args = arg_parser.parse_args()

    ort_session = onnxruntime.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])

    # Feed the preprocessed image under the name of the model's first input.
    feed = {ort_session.get_inputs()[0].name: load_and_prepare_image(cli_args.image_path)}

    # First output, first (and only) item of the batch.
    class_scores = ort_session.run(None, feed)[0][0]

    # Softmax; subtracting the maximum first keeps the exponentials from overflowing.
    shifted_exp = np.exp(class_scores - class_scores.max())
    probabilities = shifted_exp / shifted_exp.sum()

    predicted_class = probabilities.argmax()
    print(f"Predicted class: {predicted_class}")
    print(f"Confidence: {probabilities[predicted_class]:.4f}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment