Created
July 10, 2020 14:02
-
-
Save jakelevi1996/3a390ca8206518e6c4f7aabd25b6c268 to your computer and use it in GitHub Desktop.
Single script example of training a CNN on MNIST using tensorflow.keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import perf_counter

import numpy as np
from tensorflow.keras import datasets, layers, models, losses
# Get training and test data
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
# Append a trailing axis to each image array (Conv2D is applied to these
# arrays below) and divide by 255.0 -- presumably to scale 8-bit pixel
# values into the range [0, 1].
x_train = np.expand_dims(x_train, -1) / 255.0
x_test = np.expand_dims(x_test, -1) / 255.0
# Create model: one 5x5 convolution with 10 filters, 2x2 max-pooling, then a
# dense layer with 10 outputs. The final layer has no activation, so the
# model outputs raw scores (the loss is configured with from_logits=True
# when the model is compiled).
model = models.Sequential([
    layers.Conv2D(10, (5, 5), activation='relu', padding="same"),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(10),
])
# Compile and train model
model.compile(
    optimizer='adam',
    # from_logits=True because the final Dense layer applies no activation.
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
# Train on the first 10000 training samples only (one epoch, batches of 100),
# validating against the full test set.
model.fit(
    x_train[:10000], y_train[:10000], validation_data=(x_test, y_test),
    batch_size=100, epochs=1,
)
# Save and load model
# NOTE(review): the path has no file extension. Some Keras versions require
# an explicit ".keras" or ".h5" suffix for model.save -- confirm against the
# installed version before changing it.
model_path = "keras_cnn_mnist"
model.save(model_path)
print("Saved model")
loaded_model = models.load_model(model_path)
print("Loaded model")
# Perform inference with the model that was reloaded from disk, so that the
# save/load round trip above is actually exercised (the original script
# predicted with the in-memory `model` and never used `loaded_model`).
t0 = perf_counter()
preds = loaded_model.predict(x_test)
t1 = perf_counter()
# The predicted class is the index of the largest output score; accuracy is
# the fraction of predictions that match the labels, computed vectorized.
accuracy = np.mean(preds.argmax(axis=1) == y_test)
print("Finished inference; time taken = {:.3f} s".format(t1 - t0))
print("Accuracy of predictions = {:.2f} %".format(accuracy * 100))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment