Skip to content

Instantly share code, notes, and snippets.

@buttercutter
Last active December 23, 2018 12:54
Show Gist options
  • Save buttercutter/186438ad7d2e6c1ec15d9ce5d8435c13 to your computer and use it in GitHub Desktop.
Save buttercutter/186438ad7d2e6c1ec15d9ce5d8435c13 to your computer and use it in GitHub Desktop.
Python inference code for pruned SqueezeNet model
#Modified from https://github.com/amrit-das/Custom-Model-Training-PyTorch/blob/master/predict.py
import torch
import torch.nn as nn
#from torchvision.models import resnet18
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torch.functional as F
from PIL import Image
import os
import sys
import argparse
from prune import *
from finetune import *
# Command-line interface. Parsing happens at import time, so this module is
# meant to be run as a script rather than imported.
parser = argparse.ArgumentParser(description = 'To Predict from a trained model')
parser.add_argument('-i','--image', dest = 'image_name', required = True, help='Path to the image file')
parser.add_argument('-m','--model', dest = 'model_name', required = True, help='Path to the model')
# NOTE(review): --num_class is parsed but not used by any code in this file.
parser.add_argument('-n','--num_class',dest = 'num_classes', required = True, help='Number of training classes')
args = parser.parse_args()

# Use the path exactly as given. Prefixing "./" is a no-op for relative paths
# but turns an absolute path such as "/models/m.pth" into "./models/m.pth".
path_to_model = args.model_name

# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# model files that come from a trusted source.
# NOTE(review): the file is expected to hold a whole pickled model (not a
# state_dict). Recent PyTorch releases default to weights_only=True, which
# rejects such files - confirm against the installed version.
# Remap storages to the CPU when no GPU is present so that a model saved on a
# CUDA device can still be loaded.
model = torch.load(
    path_to_model,
    map_location=None if torch.cuda.is_available() else 'cpu',
)

# Inference must run in eval mode (dropout off, batch-norm using running stats).
if isinstance(model, nn.Module):
    model.eval()
def predict_image(image_path):
    """Classify one image with the module-level ``model``.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        int: Index of the highest-scoring class along dimension 1 of the
        model output.
    """
    print("prediction in progress")

    # Deterministic evaluation-time preprocessing. Random crops and flips are
    # training-time augmentation; at inference they make the prediction vary
    # from run to run for the same image.
    transformation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Force three channels: Normalize above is configured for 3-channel input,
    # so grayscale or RGBA files would otherwise fail.
    with Image.open(image_path) as image:
        image_tensor = transformation(image.convert("RGB")).float()
    image_tensor = image_tensor.unsqueeze(0)  # add the batch dimension

    # Make sure dropout / batch-norm layers behave deterministically.
    model.eval()

    # Run on whatever device the model lives on instead of assuming CUDA, so
    # the script also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        output = model(image_tensor)

    _, max_index = torch.max(output, 1)
    return max_index.item()
def class_mapping(index):
    """Return the class name recorded for ``index`` in ``class_mapping.txt``.

    The file is read from the current working directory. Each useful line has
    the form ``<name>=<index>``; lines without an ``=`` (for example blank
    lines) are skipped instead of raising ``IndexError``.

    Args:
        index: Class index; it is compared against the file using ``str(index)``.

    Returns:
        str: The name stored to the left of ``=`` on the matching line.

    Raises:
        KeyError: If no line in the file maps to ``index``.
    """
    class_map = {}
    # The context manager guarantees the file handle is closed.
    with open('class_mapping.txt', 'r') as mapping:
        for line in mapping:
            parts = line.strip().split('=')
            if len(parts) < 2:
                # Blank or malformed line: nothing to map.
                continue
            # Tolerate stray whitespace around either side of '='.
            class_map[parts[1].strip()] = parts[0].strip()
    return class_map[str(index)]
if __name__ == "__main__":
    # Historical behaviour: the image name is looked up under ./test/Lemon/.
    imagepath = "./test/Lemon/" + args.image_name
    # The --image help text promises a path to the image file, so when the
    # legacy location has no such file but the argument itself names an
    # existing file, use the argument as given.
    if not os.path.isfile(imagepath) and os.path.isfile(args.image_name):
        imagepath = args.image_name

    prediction = predict_image(imagepath)
    name = class_mapping(prediction)
    print("Predicted Class: ",name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment