Skip to content

Instantly share code, notes, and snippets.

@michelkana
Last active July 18, 2019 19:40
Show Gist options
  • Save michelkana/58141929bafc67abff74533537e71194 to your computer and use it in GitHub Desktop.
Save michelkana/58141929bafc67abff74533537e71194 to your computer and use it in GitHub Desktop.
import matplotlib.pyplot as plt
import random
%matplotlib inline
def get_mnist(noise_factor=0.5, seed=None):
    """Load MNIST and return clean and noise-corrupted copies of the images.

    Images are converted to float32, divided by 255 (presumably mapping the
    raw 0-255 pixel values into [0, 1]) and reshaped to
    ``(num_images, 28, 28, 1)``. The noisy copies add zero-mean Gaussian
    noise with standard deviation ``noise_factor`` and are then clipped back
    into [0, 1].

    NOTE(review): relies on ``np`` (NumPy) and ``mnist`` (presumably
    ``keras.datasets.mnist``) being available in module scope; this snippet
    does not import them itself.

    Args:
        noise_factor: Standard deviation of the additive Gaussian noise.
        seed: Optional integer seed for the noise. When ``None`` (default),
            the global ``np.random`` state is used exactly as before. Pass a
            value to make the noisy images reproducible.

    Returns:
        Tuple ``(x_train, x_test, x_train_noisy, x_test_noisy, y_train,
        y_test)``. All four image arrays are float32.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
    x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

    # Without a seed, keep drawing from the global NumPy state so callers who
    # already call np.random.seed(...) see the same behaviour as before.
    rng = np.random if seed is None else np.random.RandomState(seed)

    def _add_noise(images):
        # Add scaled Gaussian noise and clip back to the valid pixel range.
        # normal() returns float64; casting keeps the result float32 like the
        # clean images instead of silently doubling memory use.
        noise = rng.normal(loc=0.0, scale=1.0, size=images.shape).astype('float32')
        return np.clip(images + noise_factor * noise, 0., 1.)

    # Train noise is drawn before test noise, matching the original order.
    x_train_noisy = _add_noise(x_train)
    x_test_noisy = _add_noise(x_test)
    return x_train, x_test, x_train_noisy, x_test_noisy, y_train, y_test
# Load the dataset once at module level using the default noise level (0.5).
(x_train, x_test,
 x_train_noisy, x_test_noisy,
 y_train, y_test) = get_mnist(noise_factor=0.5)
def plot_mnist(x, y, n=10, randomly=False, labels=None):
    """Plot MNIST digits side by side in a single row of ``n`` columns.

    Args:
        x: Array of images; each item must be reshapeable to 28x28.
        y: Labels aligned with ``x``; only used to filter by ``labels``.
        n: Number of subplot columns, i.e. how many digits to show.
        randomly: If True, pick images uniformly at random (with
            replacement); otherwise show the first images in order.
        labels: Optional collection of labels. When non-empty, only images
            whose label in ``y`` is in ``labels`` are considered.
    """
    plt.figure(figsize=(20, 2))
    if labels is not None and len(labels) > 0:
        x = x[np.isin(y, labels)]

    num_images = x.shape[0]
    if num_images == 0:
        # Nothing matched (or x was empty): there is nothing to draw.
        count = 0
    elif randomly:
        # Sampling is with replacement, so every column can be filled.
        count = n
    else:
        # In order, we cannot show more images than remain after filtering.
        count = min(n, num_images)

    # Set the default colormap once rather than on every iteration.
    plt.gray()
    for i in range(count):
        ax = plt.subplot(1, n, i + 1)  # subplot positions are 1-based
        # randrange excludes its upper bound (randint does not), so j is
        # always a valid index.
        j = random.randrange(num_images) if randomly else i
        plt.imshow(x[j].reshape(28, 28))
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
# Show a row of randomly chosen digits from the noisy test set.
plot_mnist(x_test_noisy, y_test, n=10, randomly=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment