Skip to content

Instantly share code, notes, and snippets.

@RomanSteinberg
Created July 24, 2018 14:20
Show Gist options
  • Save RomanSteinberg/07b51dbdd4e7165234c760d5c934c812 to your computer and use it in GitHub Desktop.
Save RomanSteinberg/07b51dbdd4e7165234c760d5c934c812 to your computer and use it in GitHub Desktop.
Dataset + Estimator + Keras Model
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.applications import ResNet50
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Dense, Input, Dropout
def imgs_input_fn(filenames, labels=None, perform_shuffle=False, repeat_count=1, batch_size=1):
    """Build a tf.data pipeline of decoded, resized and normalized images.

    Each file is read, decoded to 3 channels, resized to 150x150 and scaled
    from [0, 255] to [-1, 1] via ``(x - 127.5) / 127.5``.

    Args:
        filenames (list): Paths of the image files to read.
        labels (list, optional): One label (or label vector) per file. When
            None, a placeholder label of 0 is used for every file.
        perform_shuffle (bool): Shuffle examples using a 256-element buffer.
        repeat_count (int): Number of passes over the data; forwarded to
            ``Dataset.repeat``.
        batch_size (int): Number of examples per batch.

    Returns:
        tuple: ``(batch_features, batch_labels)`` for the next batch, where
        ``batch_features`` is ``{'input_1': images}`` and ``batch_labels`` is
        a float32 tensor (1-D labels are expanded to shape (N, 1)).
    """
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image = tf.image.decode_image(image_string, channels=3)
        # decode_image does not report a static shape; resize_images needs a
        # known rank, and channels=3 fixes the last dimension.
        image.set_shape([None, None, 3])
        image = tf.image.resize_images(image, [150, 150])
        image = tf.div(tf.subtract(image, 127.5), 127.5)  # scale to [-1, 1]
        # The key must match the Keras model's input layer name.
        return {'input_1': image}, label

    if labels is None:
        labels = [0] * len(filenames)
    labels = np.array(labels)
    # Give 1-D labels an explicit trailing dimension: (N,) -> (N, 1).
    if labels.ndim == 1:
        labels = np.expand_dims(labels, axis=1)

    # An explicit dtype keeps an empty filename list from becoming float32.
    filenames = tf.constant(filenames, dtype=tf.string)
    labels = tf.cast(tf.constant(labels), tf.float32)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    if perform_shuffle:
        # Shuffle (filename, label) pairs before decoding so the buffer holds
        # 256 short strings rather than 256 decoded float images.
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
def create_classification_net(weights='imagenet'):
    """Build a binary image classifier: ResNet50 features plus a dense head.

    Args:
        weights: Forwarded to ``ResNet50``; ``'imagenet'`` loads pretrained
            weights, ``None`` uses random initialization.

    Returns:
        Model: Keras model whose single input layer is named ``'input_1'``.
        It takes (batch, H, W, 3) images and produces a (batch, 1) tensor of
        sigmoid probabilities.
    """
    def create_tail(input_shape):
        """Dense classifier head; ``input_shape`` excludes the batch dimension."""
        inp = Input(input_shape)
        out = Dense(4096, activation='relu')(inp)
        # NOTE(review): Keras Dropout takes the fraction of units to DROP, so
        # 0.75 discards 75% of activations. If a keep probability of 0.75 was
        # intended, this should be 0.25 — confirm.
        out = Dropout(0.75)(out)
        out = Dense(1000, activation='relu')(out)
        out = Dropout(0.75)(out)
        # A single sigmoid unit matches (batch, 1) labels and the
        # binary_crossentropy loss used for training.
        out = Dense(1, activation='sigmoid')(out)
        return Model(inputs=[inp], outputs=[out])

    # Part 1: convolutional feature extractor. include_top=False removes the
    # ImageNet classifier, which would otherwise pin the input to 224x224 and
    # ignore `pooling`. Global max pooling yields one flat vector per image,
    # so no Flatten layer is needed.
    # NOTE(review): some older Keras ResNet50 versions reject inputs smaller
    # than 197x197, and the input pipeline resizes to 150x150 — confirm
    # against the installed TF version.
    feature_extractor = ResNet50(weights=weights, include_top=False, pooling='max')

    # Part 2: classifier head. Keras *_shape attributes include the batch
    # dimension first, whereas Input() expects the per-example shape.
    classifier_tail = create_tail(feature_extractor.output_shape[1:])

    # Connect them. The input is explicitly named 'input_1' because
    # model_to_estimator matches feature-dict keys to input layer names, and
    # the input pipeline emits {'input_1': image}. An auto-generated name
    # would not match, because ResNet50 and the tail create Input layers
    # before this one.
    input_to_chain = Input(feature_extractor.input_shape[1:], name='input_1')
    out = feature_extractor(input_to_chain)
    out = classifier_tail(out)
    model = Model(inputs=[input_to_chain], outputs=[out])
    return model
def train(test_files, test_labels, experiment_folder, eval_files=None, eval_labels=None):
    """Train the classifier through the Estimator API with periodic evaluation.

    Training reads shuffled batches of 20 for up to 5 passes over the data,
    stopping at 500 steps or when that input is exhausted, whichever comes
    first.

    Args:
        test_files (list): Image paths used for training (and for evaluation
            unless ``eval_files`` is given).
        test_labels (list): Labels matching ``test_files``.
        experiment_folder (str): Estimator ``model_dir`` for checkpoints.
        eval_files (list, optional): Image paths for evaluation. Defaults to
            ``test_files``, which keeps the original behaviour but measures
            fit on the training data only.
        eval_labels (list, optional): Labels for ``eval_files``. Required when
            ``eval_files`` is given.

    Raises:
        ValueError: If ``eval_files`` is given without ``eval_labels``.
    """
    if eval_files is None:
        eval_files, eval_labels = test_files, test_labels
    elif eval_labels is None:
        # Without labels the input pipeline substitutes zeros, which would
        # make the evaluation metrics meaningless.
        raise ValueError("eval_labels is required when eval_files is given")

    model = create_classification_net()
    # compile() expects an optimizer instance or name (not the class), and
    # the metrics keyword is `metrics`, taking a list.
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    est_catvsdog = tf.keras.estimator.model_to_estimator(keras_model=model,
                                                         model_dir=experiment_folder)

    def train_input_fn():
        return imgs_input_fn(test_files, test_labels, perform_shuffle=True,
                             repeat_count=5, batch_size=20)

    def eval_input_fn():
        return imgs_input_fn(eval_files, labels=eval_labels, perform_shuffle=False)

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=500)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    tf.estimator.train_and_evaluate(est_catvsdog, train_spec, eval_spec)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment