# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Configure training: Adam optimizer, sparse categorical cross-entropy loss
# (integer class labels), and accuracy as the reported metric.
# from_logits=True: the loss expects the model to output raw, unnormalized
# scores (logits) rather than probabilities.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Train the model for 10 full passes (epochs) over the training set.
model.fit(train_images, train_labels, epochs=10)
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Measure loss and accuracy on the held-out test set.
# verbose=2 prints a single summary line instead of a progress bar.
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Append a Softmax layer to the trained model so that predictions are
# per-class probabilities instead of raw logits.
probability_model = tf.keras.Sequential([model,
                                         tf.keras.layers.Softmax()])
# One probability vector per test image.
predictions = probability_model.predict(test_images)
# The model has now predicted a label for each image in the test set.
# Look at the first prediction. A bare expression only displays its value in
# a notebook, so print it explicitly to make this work as a script too.
print(predictions[0])
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Draw image `i` from `img` with no grid or axis ticks, and compare the
# predicted class (argmax of `predictions_array`) against `true_label[i]`.
# NOTE(review): this definition is truncated as captured. The body of the
# `if` below and everything after it are missing. Presumably that part picks a
# label colour for correct vs. incorrect predictions and draws the label.
# TODO: restore it from the original source before running.
def plot_image(i, predictions_array, true_label, img): | |
# Select the i-th label and the i-th image from the full arrays.
true_label, img = true_label[i], img[i] | |
plt.grid(False) | |
plt.xticks([]) | |
plt.yticks([]) | |
# Render with matplotlib's `binary` colormap.
plt.imshow(img, cmap=plt.cm.binary) | |
# Predicted class = index of the highest score in the prediction vector.
predicted_label = np.argmax(predictions_array) | |
if predicted_label == true_label: |
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Show test image 0 (left) next to its per-class prediction values (right).
# plot_image and plot_value_array are helpers defined elsewhere in the script.
i = 0
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Show test image 12 (left) next to its per-class prediction values (right).
# plot_image and plot_value_array are helpers defined elsewhere in the script.
i = 12
plt.figure(figsize=(6, 3))
plt.subplot(1, 2, 1)
plot_image(i, predictions[i], test_labels, test_images)
plt.subplot(1, 2, 2)
plot_value_array(i, predictions[i], test_labels)
plt.show()
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Plot the first num_images test images, their predicted labels, and the true
# labels. Color correct predictions in blue and incorrect predictions in red.
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
# Each test image uses two adjacent subplots (image, then prediction values).
# The grid is therefore num_rows x (2 * num_cols), and the figure gets two
# width units per subplot.
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions[i], test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    # The right-hand subplot was created but never drawn into; fill it with
    # the prediction values, as in the single-image plots.
    plot_value_array(i, predictions[i], test_labels)
# Keep the many small subplots from overlapping, then render the figure.
plt.tight_layout()
plt.show()
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Predict a single image. Keras predict() treats the first axis as the batch,
# so `img` must carry a leading batch axis of size 1.
# NOTE(review): `img` is not defined in this snippet. Presumably it is one test
# image expanded to a batch of one, e.g. np.expand_dims(test_images[1], 0);
# the index 1 passed to plot_value_array below suggests test image 1. Confirm.
predictions_single = probability_model.predict(img)
print(predictions_single)
# Example output (pasted console output; kept as a comment so the script parses):
# [[8.8914348e-05 1.3264636e-13 9.9108773e-01 1.2658383e-10 8.1463791e-03
#   1.6905785e-08 6.7695131e-04 2.7492119e-17 5.1699739e-10 7.1339325e-17]]
# predictions_single[0] is the probability vector for the single image.
plot_value_array(1, predictions_single[0], test_labels)
# Label the 10 bars with the class names, rotated for readability.
_ = plt.xticks(range(10), class_names, rotation=45)
plt.show()
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Predicted class for the single image: index of the highest probability.
# Printed explicitly, since a bare expression only displays in a notebook.
print(np.argmax(predictions_single[0]))