# https://github.com/fchollet/keras/pull/8286
#
# An example of how to pass additional substitutions to the training function
# via the TensorFlow `feed_dict` argument to `tf.Session.run()`.
#
# Note that `feed_dict` keys are TensorFlow placeholders and values can be
# ordinary numpy arrays or other Python values.
#
# We pass additional arguments to model.compile() -> K.function() as **kwargs.
# The trick is that the feed_dict is passed by reference, i.e. even though
# the dictionary object itself stays the same, we can modify its values!
#
# Normally K.function() accepts Keras input placeholders and values in its
# `__call__()` method. This `feed_dict` allows passing additional substitutions
# that Keras doesn't know about. This might be useful e.g. to feed inputs for
# pipelining.
#
# Originally I thought it wasn't possible to pass variable values this way,
# but thanks to @TimZaman for correcting this idea
# (https://github.com/fchollet/keras/pull/8286#issuecomment-340285140).
import keras.backend as K
from keras.callbacks import Callback
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils import to_categorical
import numpy as np
import tensorflow as tf

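# --- Aside: a minimal standalone sketch (illustrative only, not part of the
# Keras example below) of the dict-by-reference mechanism described in the
# header: the dict handed to a run function is shared, so mutating it between
# calls changes what gets fed to tf.Session.run(). Assumes the TF 1.x
# graph/session API used throughout this gist.
_p = tf.placeholder(tf.int32, shape=())
_doubled = _p * 2
_shared_feed = {_p: 0}  # owned by the caller, passed by reference
_sess = tf.Session()

def _run_step():
    # reads the *current* contents of _shared_feed at call time
    return _sess.run(_doubled, feed_dict=_shared_feed)

print(_run_step())   # -> 0
_shared_feed[_p] = 21
print(_run_step())   # -> 42
# --- end of aside
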
def create_synth_dataset(input_shape, class_count, dataset_size):
    X = np.random.rand(*((dataset_size,) + input_shape)).astype(np.float32)
    y = np.random.randint(low=0, high=class_count, size=dataset_size)
    y = to_categorical(y, class_count).astype(np.float32)
    return X, y

num_classes = 2
x_train, y_train = create_synth_dataset((10,), num_classes, 50)

image = Input(shape=(10,))
x = Dense(1, activation='relu')(image)
output = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=image, outputs=output)

class FooCallback(Callback):
    def __init__(self):
        super(FooCallback, self).__init__()
        # placeholder fed through the extra feed_dict passed to compile()
        self.x = tf.placeholder(dtype=tf.int32, shape=())
        # variable updated by the extra fetch; collections=[] keeps it out of
        # the global variables, so it is initialized manually in on_train_begin()
        self.y = tf.Variable(42, trainable=False, collections=[], validate_shape=False)
        self.feed_dict = {self.x: 0}
        # extra op run in the same session.run() call as each train step
        self.fetches = [tf.assign(self.y, self.x)]

    def _get_value(self):
        return K.get_session().run(self.y)

    def on_train_begin(self, logs=None):
        K.get_session().run(self.y.initializer)
        print('before training: y =', self._get_value())

    def on_batch_begin(self, batch, logs=None):
        # update the value in the shared feed_dict; it will get propagated
        # to the next train step's session.run() call
        self.feed_dict[self.x] = batch * 100

    def on_batch_end(self, batch, logs=None):
        print('after batch %d: y =' % batch, self._get_value())

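# Conceptually (a hedged sketch with illustrative names, not the actual Keras
# backend source), the train function created by K.function() with these
# kwargs behaves roughly like the wrapper below: the caller-owned feed_dict is
# merged into the per-call feeds, and the extra fetches are run in the same
# session.run() call as the regular training outputs.
def _make_step_fn(session, outputs, extra_fetches=None, extra_feed_dict=None):
    extra_fetches = list(extra_fetches or [])
    extra_feed_dict = extra_feed_dict if extra_feed_dict is not None else {}

    def step(call_feed_dict):
        feeds = dict(call_feed_dict)
        feeds.update(extra_feed_dict)  # current contents are read on every call
        results = session.run(list(outputs) + extra_fetches, feed_dict=feeds)
        return results[:len(outputs)]  # results of the extra fetches are dropped
    return step
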
foo = FooCallback()
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              feed_dict=foo.feed_dict, fetches=foo.fetches)
model.fit(x_train, y_train, callbacks=[foo], batch_size=10, epochs=2, verbose=2)

# Expected output:
#
# before training: y = 42
# Epoch 1/2
# after batch 0: y = 0
# after batch 1: y = 100
# after batch 2: y = 200
# after batch 3: y = 300
# after batch 4: y = 400
#  - 0s - loss: 0.6981
# Epoch 2/2
# after batch 0: y = 0
# after batch 1: y = 100
# after batch 2: y = 200
# after batch 3: y = 300
# after batch 4: y = 400
#  - 0s - loss: 0.6968
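#
# Reading the output: on_batch_begin() sets the placeholder's value in the
# shared feed_dict to batch * 100, the train step then runs the extra
# tf.assign fetch with that feed, and on_batch_end() reads the variable back,
# so y tracks batch * 100 within each epoch.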