Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created October 30, 2017 08:26
Show Gist options
  • Save bzamecnik/f76e480edf98e95ab263fd1a123af7a5 to your computer and use it in GitHub Desktop.
Save bzamecnik/f76e480edf98e95ab263fd1a123af7a5 to your computer and use it in GitHub Desktop.
# https://github.com/fchollet/keras/pull/8286
#
# An example of how to pass additional substitutions to the training function
# via the TensorFlow feed_dict argument to tf.Session.run().
#
# Note that `feed_dict` keys are placeholder tensors (from `tf.placeholder`)
# and values can be ordinary numpy arrays or other Python values.
#
# We pass additional arguments to model.compile() -> K.function() as **kwargs.
# The trick is that the feed_dict is passed by reference, i.e. even though
# the same dictionary object is kept, we can modify its values!
#
# Normally K.function() accepts Keras input placeholders and values in its
# `__call__()` method. This `feed_dict` allows passing additional substitutions
# that Keras doesn't know about. This might be useful, e.g., to feed inputs for
# pipelining.
#
# Originally I thought it was not possible to pass variable values this way;
# thanks to @TimZaman for correcting me
# (https://github.com/fchollet/keras/pull/8286#issuecomment-340285140).
import keras.backend as K
from keras.callbacks import Callback
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import tensorflow as tf
def create_synth_dataset(input_shape, class_count, dataset_size):
    """Generate a random classification dataset.

    Args:
        input_shape: tuple giving the shape of one sample (without the
            leading sample dimension).
        class_count: number of target classes.
        dataset_size: number of samples to generate.

    Returns:
        Tuple ``(X, y)``. ``X`` is a float32 array of shape
        ``(dataset_size,) + input_shape`` drawn uniformly from [0, 1).
        ``y`` holds float32 one-hot labels for classes drawn uniformly
        from ``[0, class_count)``.
    """
    samples_shape = (dataset_size,) + input_shape
    # Features are drawn before labels so the global NumPy RNG is consumed
    # in a fixed order.
    features = np.random.rand(*samples_shape).astype(np.float32)
    class_ids = np.random.randint(low=0, high=class_count, size=dataset_size)
    one_hot = to_categorical(class_ids, class_count).astype(np.float32)
    return features, one_hot
# Synthetic training data and a tiny two-layer classifier for the demo.
num_classes = 2
# Shape of one sample. It is shared by the generated data and the model
# input so the two cannot drift apart.
input_shape = (10,)
x_train, y_train = create_synth_dataset(input_shape, num_classes, 50)
image = Input(shape=input_shape)
x = Dense(1, activation='relu')(image)
output = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=output)
class FooCallback(Callback):
    """Demo callback that feeds an extra placeholder through a shared feed_dict.

    ``self.feed_dict`` and ``self.fetches`` are meant to be passed to
    ``model.compile(feed_dict=..., fetches=...)``. Before each batch the
    callback mutates ``self.feed_dict`` in place, which changes the value
    fed to ``self.x``. The extra fetch ``tf.assign(self.y, self.x)`` copies
    that value into the variable ``self.y``, which is printed after each
    batch.
    """

    def __init__(self):
        # Initialize the Keras Callback base-class state before adding ours.
        super().__init__()
        # Scalar int32 value supplied through feed_dict on every batch.
        self.x = tf.placeholder(dtype=tf.int32, shape=())
        # collections=[] keeps the variable out of the default graph
        # collections, so it must be initialized explicitly
        # (done in on_train_begin).
        self.y = tf.Variable(42, trainable=False, collections=[],
                             validate_shape=False)
        # This dict object must only ever be mutated, never rebound: the
        # compiled training function holds a reference to it.
        self.feed_dict = {self.x: 0}
        # Extra op run together with each training step: y <- x.
        self.fetches = [tf.assign(self.y, self.x)]

    def _get_value(self):
        """Return the current value of ``self.y`` from the Keras session."""
        return K.get_session().run(self.y)

    def on_train_begin(self, logs=None):
        # Reset y to its initial value (42) at the start of every fit() call.
        K.get_session().run(self.y.initializer)
        print('before training: y =', self._get_value())

    def on_batch_begin(self, batch, logs=None):
        # Update the shared dict in place; the training function feeds the
        # new value when it runs this batch.
        self.feed_dict[self.x] = batch * 100

    def on_batch_end(self, batch, logs=None):
        print('after batch %d: y =' % batch, self._get_value())
# The extra compile() kwargs (feed_dict, fetches) are forwarded to
# K.function(). This relies on a Keras version that includes the pull
# request linked at the top of this file.
foo = FooCallback()
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              feed_dict=foo.feed_dict, fetches=foo.fetches)
model.fit(x_train, y_train, callbacks=[foo], batch_size=10, epochs=2, verbose=2)
# before training: y = 42
# Epoch 1/2
# after batch 0: y = 0
# after batch 1: y = 100
# after batch 2: y = 200
# after batch 3: y = 300
# after batch 4: y = 400
# - 0s - loss: 0.6981
# Epoch 2/2
# after batch 0: y = 0
# after batch 1: y = 100
# after batch 2: y = 200
# after batch 3: y = 300
# after batch 4: y = 400
# - 0s - loss: 0.6968
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment