Skip to content

Instantly share code, notes, and snippets.

@bstriner
Created June 1, 2017 01:30
Show Gist options
  • Select an option

  • Save bstriner/072bf2993cca32aadeb18e0b43833a1a to your computer and use it in GitHub Desktop.

Select an option

Save bstriner/072bf2993cca32aadeb18e0b43833a1a to your computer and use it in GitHub Desktop.
import keras.backend as K
from keras.callbacks import CSVLogger
from keras.datasets import mnist
from keras.layers import Input, Lambda, Dense, Flatten, BatchNormalization, Activation
from keras.models import Model
def main():
    """Train an MNIST classifier whose labels are fed in as a model input.

    The model takes both the images (`input_x`) and the labels (`y_true`)
    as `Input` tensors and produces the per-sample negative log likelihood
    as its only output. That output is excluded from the compiled losses
    (`loss=[None]`); the summed loss is attached with `model.add_loss`
    instead, so `fit` is called without any real targets.

    Per-epoch results are written to ``mnist_training.csv`` by a `CSVLogger`
    callback.
    """
    # Both inputs and targets are `Input` tensors
    input_x = Input((28, 28), name='input_x', dtype='uint8')  # uint8 [0-255]
    y_true = Input((1,), name='y_true', dtype='uint8')  # uint8 [0-9]

    # Build prediction network as usual
    h = Flatten()(input_x)
    h = Lambda(lambda _x: K.cast(_x, 'float32'),
               output_shape=lambda _x: _x,
               name='cast')(h)  # cast uint8 to float32
    h = BatchNormalization()(h)  # normalize pixels
    for i in range(3):  # hidden relu and batchnorm layers
        h = Dense(256)(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
    y_pred = Dense(10, activation='softmax', name='y_pred')(h)  # softmax output layer

    # Tuple parameters in a signature (`lambda (_yt, _yp): ...`) were removed
    # from the language by PEP 3113, so the pair of tensors / shapes is taken
    # as a single argument and unpacked inside the function body instead.
    def _nll(tensors):
        """Return -log(p[label] + epsilon) for every sample in the batch."""
        _yt, _yp = tensors
        rows = K.reshape(K.arange(K.shape(_yt)[0]), (-1, 1))
        # NOTE(review): indexing a tensor with two integer tensors looks like
        # it relies on Theano-style advanced indexing -- confirm the backend.
        return -K.log(_yp[rows, _yt] + K.epsilon())

    def _nll_shape(shapes):
        """Report the loss shape as the shape of `y_true` (the first input)."""
        _yt_shape, _yp_shape = shapes
        return _yt_shape

    # Lambda layer performs loss calculation (negative log likelihood)
    loss = Lambda(_nll, output_shape=_nll_shape, name='loss')([y_true, y_pred])

    # Model `inputs` are both x and y. `outputs` is the loss.
    model = Model(inputs=[input_x, y_true], outputs=[loss])
    # Manually add the loss to the model
    model.add_loss(K.sum(loss, axis=None))
    # Compile with the loss weight set to None, so it will be omitted
    model.compile('adam', loss=[None], loss_weights=[None])

    # Add accuracy to the metrics
    # Cannot add as a metric to compile, because metrics for skipped outputs are skipped
    accuracy = K.mean(K.equal(K.argmax(y_pred, axis=1), K.flatten(y_true)))
    model.metrics_names.append('accuracy')
    model.metrics_tensors.append(accuracy)

    # Model summary
    model.summary()

    # Train model
    train, test = mnist.load_data()
    cb = CSVLogger("mnist_training.csv")
    # `list(train)` / `list(test)` are passed as the model inputs; the
    # target list is `[None]` because the only output has no compiled loss.
    model.fit(list(train), [None], epochs=300, batch_size=128, callbacks=[cb],
              validation_data=(list(test), [None]))
if __name__ == "__main__":
    # Run the training demo only when executed as a script, not on import.
    main()
@bstriner
Copy link
Copy Markdown
Author

bstriner commented Jun 1, 2017

This gist shows how to use `None` to omit outputs during training, so that an `Input` can serve as both an input and a target. The model then has no separate targets when you train it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment