@TsaiKoga
Created December 12, 2019 05:58
PyTorch Helper: Plot the MNIST image and percentage graph
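```python
import matplotlib.pyplot as plt
import numpy as np


def view_classify(img, ps, version="MNIST"):
    ''' Function for viewing an image and its predicted classes.

    img: tensor holding one 28x28 image (any shape with 784 elements).
    ps: tensor of 10 class probabilities, e.g. shape (1, 10).
    version: "MNIST" for digit labels, "Fashion" for Fashion-MNIST labels.
    '''
    # Detach from the autograd graph and move to the CPU before converting to NumPy.
    ps = ps.detach().cpu().numpy().squeeze()
    # reshape (unlike the in-place resize_) leaves the caller's tensor untouched.
    img = img.detach().cpu().reshape(28, 28).numpy()

    fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    # Left panel: the input image.
    ax1.imshow(img)
    ax1.axis('off')
    # Right panel: horizontal bar chart of the class probabilities.
    ax2.barh(np.arange(10), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    if version == "MNIST":
        ax2.set_yticklabels(np.arange(10))
    elif version == "Fashion":
        ax2.set_yticklabels(['T-shirt/top',
                             'Trouser',
                             'Pullover',
                             'Dress',
                             'Coat',
                             'Sandal',
                             'Shirt',
                             'Sneaker',
                             'Bag',
                             'Ankle Boot'], size='small')
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)

    plt.tight_layout()
```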
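`view_classify` draws ten bars on an x-axis running from 0 to 1.1 under the title "Class Probability", so `ps` should hold probabilities for the ten classes, not raw network scores. Here is a minimal sketch of producing both inputs. It assumes a hypothetical, untrained `nn.Sequential` classifier that ends in `LogSoftmax`, and a random tensor stands in for a real MNIST digit. Neither comes from the original gist.

```python
import torch
from torch import nn

# Hypothetical stand-in for a trained classifier: a flattened 28x28 input
# mapped to log-probabilities over 10 classes (random, untrained weights).
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1),
)

# Dummy image in the (1, 1, 28, 28) shape a DataLoader batch of size 1 yields.
img = torch.rand(1, 1, 28, 28)

# Inference only, so skip gradient tracking.
with torch.no_grad():
    logps = model(img.view(1, 784))  # shape (1, 10), log-probabilities

# Exponentiate to recover class probabilities that sum to 1 per row.
ps = torch.exp(logps)

print(ps.shape)  # torch.Size([1, 10])
```

Because this model returns log-probabilities, `torch.exp` turns them back into probabilities. A model that returns raw logits would use `torch.softmax(logits, dim=1)` instead. The resulting tensors can then be passed as `view_classify(img, ps)`, or as `view_classify(img, ps, version="Fashion")` for a Fashion-MNIST model.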