Skip to content

Instantly share code, notes, and snippets.

@ashhadulislam
Created July 26, 2022 18:33
Show Gist options
  • Save ashhadulislam/98044ef77944b3e4905ebc810b15aa5c to your computer and use it in GitHub Desktop.
Save ashhadulislam/98044ef77944b3e4905ebc810b15aa5c to your computer and use it in GitHub Desktop.
def predict(model, image_url, *, timeout=30):
    """Rank every entry of ``classes`` by the model's score for one image.

    Args:
        model: Callable (e.g. a ``torch.nn.Module``) that is given a
            ``(1, 3, 64, 64)`` tensor and returns scores indexed as
            ``[batch, class]``.
        image_url: An HTTP(S) URL, or anything else ``PIL.Image.open``
            accepts (typically a local file path).
        timeout: Seconds to wait for the server when ``image_url`` is a URL.
            Keyword-only; without a timeout ``requests`` may block forever.

    Returns:
        A list of entries from ``classes`` (the pox types), ordered from the
        highest model score to the lowest.

    Raises:
        requests.HTTPError: If downloading ``image_url`` returns an error
            status (instead of handing an HTML error page to PIL).
        requests.exceptions.Timeout: If the download exceeds ``timeout``.
    """
    # validators.url returns True on success and a falsy failure object
    # otherwise, hence the explicit comparison with True.
    if validators.url(image_url) is True:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_url

    # The context manager closes the file PIL opens for a local path.
    # The transform runs inside it so pixel data is read before closing.
    with Image.open(source) as picture:
        # Convert the image to grayscale and other transforms.
        image = data_transform(picture)

    # NOTE(review): assumes data_transform yields 64*64 single-channel
    # values -- confirm against its definition. The single channel is
    # replicated three times, presumably because the model expects RGB input.
    images = image.reshape(1, 1, 64, 64).repeat(1, 3, 1, 1)

    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        outputs = model(images)

    # Class indices of the single batch row, highest score first. tolist()
    # yields plain Python ints and also works if outputs live on a GPU.
    ranked_labels = torch.argsort(outputs, dim=1, descending=True)[0].tolist()
    return [classes[label] for label in ranked_labels]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment