Skip to content

Instantly share code, notes, and snippets.

@michelkana
Created July 18, 2019 19:51
Show Gist options
  • Save michelkana/c8a86743440cc54de9666d591abc6ca4 to your computer and use it in GitHub Desktop.
Save michelkana/c8a86743440cc54de9666d591abc6ca4 to your computer and use it in GitHub Desktop.
# plot de-noised images: noisy inputs (top row) above autoencoder outputs (bottom row)
def plot_mnist_predict(x_test, x_test_noisy, autoencoder, y_test, labels=None,
                       n=10, img_shape=(28, 28)):
    """Plot noisy test images above the autoencoder's reconstructions.

    Parameters
    ----------
    x_test : numpy.ndarray
        Clean test images; this is what is passed to ``autoencoder.predict``.
        Each sample must be reshapeable to ``img_shape``.
    x_test_noisy : numpy.ndarray
        Noisy test images, aligned index-for-index with ``x_test``; these
        are drawn in the top row.
    autoencoder : object
        Any model exposing a ``predict(x)`` method.
    y_test : numpy.ndarray
        Labels aligned with ``x_test``; only consulted when ``labels`` is
        non-empty.
    labels : sequence, optional
        If non-empty, keep only samples whose label is in ``labels``.
        ``None`` (default) or an empty sequence keeps every sample.
    n : int, default 10
        Maximum number of image pairs to draw. Clamped to the number of
        available samples so a small filtered subset does not raise
        ``IndexError``.
    img_shape : tuple of int, default (28, 28)
        Shape each sample is reshaped to for display.

    Returns
    -------
    tuple
        ``(decoded_imgs, x_test)``: the ``autoencoder.predict`` output and
        the (possibly label-filtered) clean inputs it was computed from.
    """
    if labels is not None and len(labels) > 0:
        # Build the mask once and apply it to both arrays so they stay aligned.
        mask = np.isin(y_test, labels)
        x_test = x_test[mask]
        x_test_noisy = x_test_noisy[mask]

    # NOTE(review): reconstructions are computed from the clean x_test, while
    # the top row displays x_test_noisy. If this plot is meant to show
    # denoising, autoencoder.predict(x_test_noisy) is probably intended --
    # confirm before changing, since callers receive decoded_imgs.
    decoded_imgs = autoencoder.predict(x_test)

    n_shown = min(n, len(x_test))
    # Two inches of width per column; max() keeps the figure non-degenerate
    # when nothing survives the label filter.
    plt.figure(figsize=(2 * max(n_shown, 1), 4))
    for i in range(n_shown):
        # Row 0: noisy input, row 1: reconstruction, same column i.
        for row, images in enumerate((x_test_noisy, decoded_imgs)):
            ax = plt.subplot(2, n_shown, row * n_shown + i + 1)
            ax.imshow(images[i].reshape(img_shape), cmap="gray")
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()
    return decoded_imgs, x_test
# Plot the whole test set (no label filter) and keep the reconstructions
# together with the unfiltered clean inputs that plot_mnist_predict returns.
decoded_imgs_test, x_test_new = plot_mnist_predict(
    x_test=x_test,
    x_test_noisy=x_test_noisy,
    autoencoder=autoencoder,
    y_test=y_test,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment