Last active
November 22, 2019 09:38
-
-
Save n0obcoder/6c88c8343ebde778671acc6b5591da79 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import torch.nn.functional as F | |
# Making a 'predict' function which takes the 'model' and the path of the 'test image' as inputs, and predicts the class that the test image belongs to.
def predict(model, test_img_path):
    """Classify a single image file with ``model`` and print the top class.

    The image is shown with ``plt.imshow``, preprocessed with the
    ``data_transforms`` pipeline, and run through ``model`` in eval mode
    without gradient tracking. The arg-max class is looked up in
    ``dataset.classes`` and printed together with its softmax confidence
    (as a percentage rounded to 2 decimals).

    Args:
        model: PyTorch module; presumably maps a ``(1, C, H, W)`` batch to
            logits of shape ``(1, num_classes)`` (softmax/max are taken over
            dim 1).
        test_img_path: Path of the image file to classify.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``test_img_path``.

    Note:
        ``plt``, ``transforms``, ``data_transforms`` and ``dataset`` are not
        defined in this snippet and must exist at module level.
    """
    img = cv2.imread(test_img_path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(test_img_path))
    # OpenCV loads images in BGR order. Convert once and use the RGB version
    # both for display and for the model.
    # NOTE(review): assumes the model was trained on RGB images (e.g. loaded
    # through PIL / torchvision ImageFolder); confirm against training code.
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Visualizing the test image
    plt.imshow(img_rgb)
    img_tensor = data_transforms(transforms.ToPILImage()(img_rgb))
    # Add a batch dimension: (C, H, W) -> (1, C, H, W). Works for
    # non-square images too.
    img_tensor = img_tensor.unsqueeze(0)
    model.eval()
    with torch.no_grad():
        logits = model(img_tensor)
    probs = F.softmax(logits, dim=1)
    max_prob, ind = torch.max(probs, 1)
    print('This Neural Network thinks that the given image belongs to >>> {} <<< class with confidence of {}%'.format(dataset.classes[ind.item()], round(max_prob.item()*100, 2)))
test_data_dir = 'mobile_gallery_image_classification_data/mobile_gallery_image_classification/test'
# Collect every entry of the test directory. glob returns paths in an
# arbitrary, filesystem-dependent order, so sort them to make
# test_img_index below select the same image on every machine and run.
test_img_list = sorted(glob.glob(os.path.join(test_data_dir, '*')))
# Loading the trained model (architecture as well as the weights) for making inferences.
# NOTE: torch.load unpickles the file, so only load checkpoints from trusted
# sources. On PyTorch >= 2.6 the default weights_only=True rejects fully
# pickled models; pass weights_only=False explicitly in that case.
model = torch.load('stage2.pth')
# Select the test image index (0 .. len(test_img_list) - 1; the original
# dataset has 7 test images, i.e. 0 to 6).
test_img_index = 3
# Explicit check: a negative index would otherwise silently wrap around.
if not 0 <= test_img_index < len(test_img_list):
    raise IndexError('test_img_index {} out of range for {} test images in {}'.format(
        test_img_index, len(test_img_list), test_data_dir))
predict(model, test_img_list[test_img_index])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This Neural Network thinks that the given image belongs to >>> Memes <<< class with confidence of 95.21%
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment