Skip to content

Instantly share code, notes, and snippets.

@jakelevi1996
Created July 10, 2020 14:02
Show Gist options
  • Save jakelevi1996/3a390ca8206518e6c4f7aabd25b6c268 to your computer and use it in GitHub Desktop.
Save jakelevi1996/3a390ca8206518e6c4f7aabd25b6c268 to your computer and use it in GitHub Desktop.
Single-script example of training a CNN on MNIST using tensorflow.keras.
from time import perf_counter
import numpy as np
from tensorflow.keras import datasets, layers, models, losses
# Get training and test data. Each image array gains a trailing channel axis
# (expand_dims at -1) so it can be fed to Conv2D, and pixel values are scaled
# from [0, 255] down to [0, 1].
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
# Cast to float32 *before* dividing: dividing the raw integer images by 255.0
# would promote them to float64, doubling memory, even though the model's
# layers compute in Keras' default float32 (assuming floatx was not changed).
x_train = np.expand_dims(x_train, -1).astype(np.float32) / 255.0
x_test = np.expand_dims(x_test, -1).astype(np.float32) / 255.0
# Create model: one 5x5 convolution with 10 ReLU filters ("same" padding keeps
# the input's spatial size), 2x2 max-pooling, a flatten, and a final dense
# layer with no activation that emits 10 raw class scores (logits).
model = models.Sequential([
    layers.Conv2D(10, (5, 5), activation='relu', padding="same"),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(10),
])
# Compile and train model. The output layer has no activation, so the loss is
# configured to expect logits. To keep the example quick, training uses only
# the first 10,000 examples for a single epoch, and the test set doubles as
# validation data.
loss_fn = losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])
n_train = 10000
model.fit(
    x_train[:n_train],
    y_train[:n_train],
    batch_size=100,
    epochs=1,
    validation_data=(x_test, y_test),
)
# Save and load model: write the trained model to disk, then restore it into
# `loaded_model`.
# The ".keras" suffix selects Keras' native single-file format. Keras 3 (the
# default Keras from TensorFlow 2.16) raises a ValueError in Model.save for a
# path with no extension, so a bare "keras_cnn_mnist" no longer works there.
# NOTE(review): older tf.keras releases may choose a different on-disk format
# for this suffix — confirm against the installed TensorFlow version.
model_path = "keras_cnn_mnist.keras"
model.save(model_path)
print("Saved model")
loaded_model = models.load_model(model_path)
print("Loaded model")
# Perform inference with the model restored from disk, so the reported
# accuracy actually exercises the save/load round trip.
t0 = perf_counter()
preds = loaded_model.predict(x_test)
t1 = perf_counter()
# preds has one row of 10 logits per test image; the predicted class is the
# argmax along axis 1. The mean of the boolean match array is the fraction of
# correct predictions, computed without materializing a Python list.
accuracy = np.mean(preds.argmax(axis=1) == y_test)
print("Finished inference; time taken = {:.3f} s".format(t1 - t0))
print("Accuracy of predictions = {:.2f} %".format(accuracy * 100))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment