# TFX Trainer/Tuner module by @ntakouris (gist created July 29, 2020):
# a Keras Tuner (Hyperband) search plus a Keras binary classifier exported
# with a tf.Transform serving signature.
#
# Import paths below assume the TFX/KerasTuner versions current when this
# gist was written (TFX ~0.22). DENSE_FEATURES, BINARY_FEATURES and
# LABEL_KEY are feature-name constants the gist references but does not
# define here.
import logging
import os

import kerastuner
import tensorflow as tf
import tensorflow_transform as tft
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_transform.tf_metadata import schema_utils
from tfx.components.trainer.executor import TrainerFnArgs
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.tuner.component import TunerFnResult

# Name of the tunable hidden-layer-size hyperparameter.
H_SIZE = 'h_size'


def _get_hyperparameters() -> kerastuner.HyperParameters:
    """Returns the search space: a hidden layer of either 5 or 10 units."""
    hp = kerastuner.HyperParameters()
    hp.Choice(H_SIZE, [5, 10])
    return hp


def _build_keras_model(hparams: kerastuner.HyperParameters) -> tf.keras.Model:
    """Builds a binary classifier over the transformed ('_xf') features."""
    features_in = []
    features_in.extend(DENSE_FEATURES)
    features_in.extend(BINARY_FEATURES)
    # The Transform component materializes features under a '_xf' suffix.
    features_in = [f'{x}_xf' for x in features_in]

    input_layers = {
        colname: tf.keras.layers.Input(
            name=colname, shape=(None, 1), dtype=tf.float32)
        for colname in features_in
    }

    # Concatenate expects a list of tensors, not a dict view.
    x = tf.keras.layers.Concatenate(axis=-1)(list(input_layers.values()))
    h = int(hparams.get(H_SIZE))
    x = tf.keras.layers.Dense(units=h, activation='relu')(x)
    out = tf.keras.layers.Dense(units=1, activation='sigmoid')(x)

    model = tf.keras.Model(input_layers, out)
    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=[tf.keras.metrics.BinaryAccuracy()])
    model.summary(print_fn=logging.info)
    return model
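

# `_input_fn` is called below but not included in this gist. The following is
# a minimal sketch of what it might look like, assuming the Transform output
# is gzip-compressed TFRecords of transformed examples, that the label was
# materialized under f'{LABEL_KEY}_xf', and an arbitrary batch size of 64.
def _input_fn(file_pattern, tf_transform_output, batch_size=64):
    """Sketch: batched tf.data.Dataset of transformed features and labels."""
    transformed_feature_spec = (
        tf_transform_output.transformed_feature_spec().copy())
    return tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        batch_size=batch_size,
        features=transformed_feature_spec,
        reader=lambda filenames: tf.data.TFRecordDataset(
            filenames, compression_type='GZIP'),
        label_key=f'{LABEL_KEY}_xf')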


def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Builds the Hyperband tuner that the TFX Tuner component will drive."""
    train_files = fn_args.train_files
    eval_files = fn_args.eval_files
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    hparams = _get_hyperparameters()
    tuner = kerastuner.Hyperband(
        hypermodel=_build_keras_model,
        hyperparameters=hparams,
        objective=kerastuner.Objective('binary_accuracy', 'max'),
        factor=3,
        max_epochs=2,
        directory=fn_args.working_dir,
        project_name='ftfx:simple_e2e')

    train_dataset = _input_fn(train_files, tf_transform_output)
    eval_dataset = _input_fn(eval_files, tf_transform_output)

    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            'x': train_dataset,
            'validation_data': eval_dataset,
            'steps_per_epoch': fn_args.train_steps,
            'validation_steps': fn_args.eval_steps
        })
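

# For context, a sketch (an assumption, not part of the original gist) of how
# tuner_fn is typically wired into a TFX pipeline. The names `transform`,
# `schema_gen` and `module_file` are placeholders:
#
#   from tfx.components import Tuner
#   from tfx.proto import trainer_pb2
#
#   tuner = Tuner(
#       module_file=module_file,
#       examples=transform.outputs['transformed_examples'],
#       transform_graph=transform.outputs['transform_graph'],
#       schema=schema_gen.outputs['schema'],
#       train_args=trainer_pb2.TrainArgs(num_steps=20),
#       eval_args=trainer_pb2.EvalArgs(num_steps=5))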


def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Returns a fn that parses serialized tf.Examples, applies the
    Transform graph, and runs the model."""
    # Keep a reference so the Transform layer is tracked (and saved) with
    # the model.
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        feature_spec = tf_transform_output.raw_feature_spec()
        # The label is not part of a serving request.
        feature_spec.pop(LABEL_KEY)
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)

    return serve_tf_examples_fn
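

# A hedged sketch (not from the original gist) of how a client might call the
# exported signature; the feature name 'some_feature' is a placeholder:
#
#   loaded = tf.saved_model.load(serving_model_dir)
#   predict = loaded.signatures['serving_default']
#   example = tf.train.Example(features=tf.train.Features(feature={
#       'some_feature': tf.train.Feature(
#           float_list=tf.train.FloatList(value=[0.5])),
#   }))
#   outputs = predict(examples=tf.constant([example.SerializeToString()]))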


def run_fn(fn_args: TrainerFnArgs):
    """Trains the model and exports it with a serving signature."""
    hparams = fn_args.hyperparameters
    # The Tuner's best_hyperparameters artifact arrives as a config dict;
    # unwrap the raw values so hparams.get(...) keeps working.
    if isinstance(hparams, dict) and 'values' in hparams:
        hparams = hparams['values']

    # Parse the pipeline schema into a raw feature spec.
    schema = schema_pb2.Schema()
    schema_text = file_io.read_file_to_string(fn_args.schema_file)
    text_format.Parse(schema_text, schema)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = _input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output)

    # Build the model under a distribution strategy so training can use all
    # local GPUs.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = _build_keras_model(hparams=hparams)

    # fn_args.model_run_dir does not exist on older TFX versions; attribute
    # access raises AttributeError (not KeyError) in that case.
    try:
        log_dir = fn_args.model_run_dir
    except AttributeError:
        log_dir = os.path.join(
            os.path.dirname(fn_args.serving_model_dir), 'logs')

    # Write TensorBoard logs once per batch.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='batch')

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(shape=[None], dtype=tf.string,
                                  name='examples'))
    }
    model.save(fn_args.serving_model_dir,
               save_format='tf', signatures=signatures)
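

# A matching sketch (again an assumption, not part of this gist) of a Trainer
# component consuming the Tuner's result; `tuner`, `transform`, `schema_gen`
# and `module_file` are placeholders, and import paths vary across TFX
# versions:
#
#   from tfx.components import Trainer
#   from tfx.components.trainer.executor import GenericExecutor
#   from tfx.dsl.components.base import executor_spec
#   from tfx.proto import trainer_pb2
#
#   trainer = Trainer(
#       module_file=module_file,
#       custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
#       examples=transform.outputs['transformed_examples'],
#       transform_graph=transform.outputs['transform_graph'],
#       schema=schema_gen.outputs['schema'],
#       hyperparameters=tuner.outputs['best_hyperparameters'],
#       train_args=trainer_pb2.TrainArgs(num_steps=100),
#       eval_args=trainer_pb2.EvalArgs(num_steps=50))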