Created
July 29, 2020 13:22
-
-
Save ntakouris/479f520f38650d49307406012cc8fa24 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Hyperparameter name for the width (unit count) of the hidden ReLU Dense layer.
H_SIZE = "h_size"
def _get_hyperparameters() -> kerastuner.HyperParameters:
    """Define the hyperparameter search space used by the tuner.

    Returns:
        A ``kerastuner.HyperParameters`` holding one categorical entry,
        ``H_SIZE``, whose candidates are 5 and 10 (the hidden-layer width).
    """
    search_space = kerastuner.HyperParameters()
    search_space.Choice(H_SIZE, values=[5, 10])
    return search_space
def _build_keras_model(hparams: kerastuner.HyperParameters) -> tf.keras.Model:
    """Build and compile a single-hidden-layer binary classifier.

    The model creates one float32 Keras ``Input`` per transformed feature:
    every name in ``DENSE_FEATURES`` and ``BINARY_FEATURES``, suffixed with
    ``_xf``. It then:

    - concatenates those inputs along the last axis;
    - applies a ReLU ``Dense`` layer whose width is the ``H_SIZE``
      hyperparameter;
    - ends with a single sigmoid unit.

    Args:
        hparams: Source of the ``H_SIZE`` value. Only ``hparams.get(H_SIZE)``
            is used, so either a ``kerastuner.HyperParameters`` or a plain
            ``dict`` of values works. ``run_fn`` may pass the latter.

    Returns:
        The model, compiled with binary cross-entropy loss, the Adam optimizer
        and a ``BinaryAccuracy`` metric.
    """
    feature_names = [
        f'{name}_xf'
        for group in (DENSE_FEATURES, BINARY_FEATURES)
        for name in group
    ]

    # NOTE(review): shape=(None, 1) makes every input rank 3
    # (batch, <variable length>, 1). Confirm this matches what _input_fn
    # yields and the label's shape.
    input_layers = {
        colname: tf.keras.layers.Input(
            name=colname, shape=(None, 1), dtype=tf.float32)
        for colname in feature_names
    }

    # Keras merge layers take a *list* of tensors, so pass a real list
    # rather than a dict_values view.
    x = tf.keras.layers.Concatenate(axis=-1)(list(input_layers.values()))

    hidden_units = int(hparams.get(H_SIZE))
    x = tf.keras.layers.Dense(units=hidden_units, activation='relu')(x)
    out = tf.keras.layers.Dense(units=1, activation='sigmoid')(x)

    # Dict inputs let callers feed features by name.
    model = tf.keras.Model(input_layers, out)
    model.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=[tf.keras.metrics.BinaryAccuracy()])
    model.summary(print_fn=logging.info)
    return model
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Create the Hyperband tuner and the datasets/steps it should fit with.

    Args:
        fn_args: TFX-supplied arguments. This function reads ``train_files``,
            ``eval_files``, ``transform_graph_path``, ``working_dir``,
            ``train_steps`` and ``eval_steps``.

    Returns:
        A ``TunerFnResult`` that pairs two things:

        - a ``kerastuner.Hyperband`` tuner maximising ``binary_accuracy``,
          with state under ``fn_args.working_dir``;
        - ``fit_kwargs`` holding the training dataset (``x``), the validation
          dataset and the per-epoch step counts.
    """
    transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    hyperband = kerastuner.Hyperband(
        hypermodel=_build_keras_model,
        hyperparameters=_get_hyperparameters(),
        objective=kerastuner.Objective('binary_accuracy', 'max'),
        factor=3,
        max_epochs=2,
        directory=fn_args.working_dir,
        project_name='ftfx:simple_e2e',
    )

    fit_kwargs = dict(
        x=_input_fn(fn_args.train_files, transform_output),
        validation_data=_input_fn(fn_args.eval_files, transform_output),
        steps_per_epoch=fn_args.train_steps,
        validation_steps=fn_args.eval_steps,
    )
    return TunerFnResult(tuner=hyperband, fit_kwargs=fit_kwargs)
def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Return a ``tf.function`` that scores a batch of serialized examples.

    The returned function does three things in order:

    1. parses its string input with the raw feature spec, minus
       ``LABEL_KEY``;
    2. applies the tf.Transform layer;
    3. feeds the transformed features to ``model``.

    Args:
        model: The trained Keras model. As a side effect, the tf.Transform
            layer is stored on it as ``model.tft_layer``.
        tf_transform_output: A ``tft.TFTransformOutput``. It supplies both the
            raw feature spec and the transform-features layer.

    Returns:
        The ``@tf.function``-decorated ``serve_tf_examples_fn``.
    """
    # Attach the layer to the model, presumably so Keras tracks it and it is
    # saved together with the model.
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        spec = tf_transform_output.raw_feature_spec()
        # Drop the label so incoming examples need not carry it.
        del spec[LABEL_KEY]
        raw = tf.io.parse_example(serialized_tf_examples, spec)
        return model(model.tft_layer(raw))

    return serve_tf_examples_fn
def run_fn(fn_args: TrainerFnArgs):
    """Train the model with the selected hyperparameters and export it.

    The function:

    - builds the model inside a ``tf.distribute.MirroredStrategy`` scope;
    - fits it on the transformed train/eval datasets, logging to TensorBoard;
    - saves it to ``fn_args.serving_model_dir`` in SavedModel format, with a
      ``serving_default`` signature that takes a batch of serialized
      ``tf.Example`` strings.

    Args:
        fn_args: TFX trainer arguments. Reads ``hyperparameters``,
            ``transform_output``, ``train_files``, ``eval_files``,
            ``train_steps``, ``eval_steps``, ``serving_model_dir`` and, when
            present, ``model_run_dir``.
    """
    hparams = fn_args.hyperparameters
    # Hyperparameters may arrive as a dict carrying a 'values' entry
    # (presumably a serialized HyperParameters config). The inner dict is
    # enough, because _build_keras_model only calls .get() on it.
    if isinstance(hparams, dict) and 'values' in hparams:
        hparams = hparams['values']
    if hparams is None:
        # No tuned values were supplied (e.g. no Tuner upstream). Fall back to
        # the search space, whose H_SIZE resolves to its default choice.
        hparams = _get_hyperparameters()

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = _input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output)

    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = _build_keras_model(hparams=hparams)

    # A missing field raises AttributeError on attribute-style FnArgs. The
    # original code only caught KeyError, presumably raised by an older
    # dict-backed TrainerFnArgs, so catch both. Also fall back when the
    # field exists but is empty/None.
    try:
        log_dir = fn_args.model_run_dir
    except (AttributeError, KeyError):
        log_dir = None
    if not log_dir:
        log_dir = os.path.join(
            os.path.dirname(fn_args.serving_model_dir), 'logs')

    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='batch')

    model.fit(
        train_dataset,
        steps_per_epoch=fn_args.train_steps,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback])

    signatures = {
        'serving_default':
            _get_serve_tf_examples_fn(
                model, tf_transform_output).get_concrete_function(
                    tf.TensorSpec(
                        shape=[None], dtype=tf.string, name='examples')),
    }
    model.save(
        fn_args.serving_model_dir, save_format='tf', signatures=signatures)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment