Skip to content

Instantly share code, notes, and snippets.

@satr
Created May 4, 2019 08:18
Show Gist options
  • Save satr/66b6a5f42a3e18904ea4ae1fb8b3b5f4 to your computer and use it in GitHub Desktop.
Save satr/66b6a5f42a3e18904ea4ae1fb8b3b5f4 to your computer and use it in GitHub Desktop.
Learn image recognition with fashion_mnist
# Based on https://www.coursera.org/learn/introduction-tensorflow/home/welcome
# Course 1 - Part 4 - Lesson 2 - Notebook
#https://github.com/zalandoresearch/fashion-mnist
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

# Load Fashion-MNIST via Keras: grayscale clothing images, each labelled with
# an integer class id (see the `classes` mapping below for the names).
mnist = keras.datasets.fashion_mnist
(training_images, training_labels), (test_images, test_labels) = mnist.load_data()

# Scale pixel intensities into [0, 1] (assumes 8-bit pixel data in 0..255).
training_images = training_images / 255.0
test_images = test_images / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),                            # image -> 1-D vector
    tf.keras.layers.Dense(128, activation=tf.nn.relu),    # 128 hidden neurons
    tf.keras.layers.Dense(10, activation=tf.nn.softmax),  # one output per class
])
# The optimizer is selected by name: tf.train.AdamOptimizer is a TF1-only API
# that was removed in TensorFlow 2.x, whereas 'adam' works in both versions.
# Labels are integer class ids, hence the "sparse" categorical cross-entropy.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(training_images, training_labels, epochs=5)
model.evaluate(test_images, test_labels)

# Per-class probabilities for every test image.
classifications = model.predict(test_images)

# Human-readable names for the integer class ids.
classes = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}

# Number of test images to display alongside their predicted class.
show_items = 10

# Most probable class id for each test image; computed once rather than
# re-running argmax over the whole prediction array on every iteration.
predicted_classes = classifications.argmax(axis=1)

# squeeze=False keeps axarr 2-D (1 x show_items), so indexing still works when
# show_items == 1 (plt.subplots would otherwise return a single bare Axes).
_, axarr = plt.subplots(1, show_items, squeeze=False,
                        figsize=(2 * show_items, 2))
for index, ax in enumerate(axarr[0]):
    predicted_name = classes[predicted_classes[index]]
    # The images are single-channel, so render them in grayscale instead of
    # with matplotlib's default colour map.
    ax.imshow(test_images[index], cmap='gray')
    ax.set_title(predicted_name, fontsize=8)
    ax.axis('off')
    print(predicted_name)
# Without this, a non-interactive script never displays the figure.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment