Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created September 19, 2020 09:40
Show Gist options
  • Save ntakouris/30e80d1178e939bb5f9bf3522fe8d2ef to your computer and use it in GitHub Desktop.
Save ntakouris/30e80d1178e939bb5f9bf3522fe8d2ef to your computer and use it in GitHub Desktop.
import os
from functools import partial
from typing import Any, Dict, List, Text
import kerastuner
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_transform as tft
import tensorflow_data_validation as tfdv
from absl import logging
from tensorflow_transform.tf_metadata import schema_utils
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.tuner.component import TunerFnResult
from rnn.constants import (BATCH_SIZE, DENSE_FLOAT_FEATURE_KEYS, FEATURE_KEYS, PREDICT_FEATURE_KEYS,
INPUT_FEATURE_KEYS, INPUT_WINDOW_SIZE, OUTPUT_WINDOW_SIZE,
HYPERPARAMETERS)
from rnn.model import build_keras_model
from input_fn_utils import input_fn, get_serve_raw_fn, _preprocessing_fn
def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
    """Run the shared preprocessing routine with this module's feature keys.

    Thin adapter around ``_preprocessing_fn``: ``inputs`` is forwarded
    unchanged, together with the ``DENSE_FLOAT_FEATURE_KEYS`` and
    ``INPUT_FEATURE_KEYS`` constants.

    Args:
        inputs: Dictionary keyed by text, handed to ``_preprocessing_fn`` as-is.

    Returns:
        The value returned by ``_preprocessing_fn``.
    """
    feature_key_options = {
        'dense_float_feature_keys': DENSE_FLOAT_FEATURE_KEYS,
        'input_feature_keys': INPUT_FEATURE_KEYS,
    }
    return _preprocessing_fn(inputs, **feature_key_options)
def _input_fn(train_files, tf_transform_output, feature_spec):
    """Build a dataset via ``input_fn`` using this module's constants.

    Binds the windowing, batching and feature-key constants so callers only
    supply the files, the transform output and the feature spec.

    Args:
        train_files: Files to read; forwarded as the first positional argument
            of ``input_fn``.
        tf_transform_output: Transform output forwarded to ``input_fn``.
        feature_spec: Feature spec forwarded to ``input_fn``.

    Returns:
        The value returned by ``input_fn``.
    """
    bound_settings = dict(
        feature_spec=feature_spec,
        input_window_size=INPUT_WINDOW_SIZE,
        output_window_size=OUTPUT_WINDOW_SIZE,
        batch_size=BATCH_SIZE,
        predict_feature_keys=PREDICT_FEATURE_KEYS,
        feature_keys=FEATURE_KEYS,
        input_feature_keys=INPUT_FEATURE_KEYS,
    )
    return input_fn(train_files, tf_transform_output, **bound_settings)
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Bundle a tuner and its ``fit`` keyword arguments into a ``TunerFnResult``.

    NOTE(review): the code that defines ``tuner``, ``train_dataset`` and
    ``eval_dataset`` is elided in this snippet (the ``# ...`` placeholder
    below); it must be restored before this function can run.

    Args:
        fn_args: Provides ``train_steps`` and ``eval_steps``, used as the
            per-epoch and validation step counts.

    Returns:
        A ``TunerFnResult`` holding the tuner and the keyword arguments
        (``x``, ``validation_data``, ``steps_per_epoch``, ``validation_steps``).
    """
    # ...
    fit_kwargs = {
        'x': train_dataset,
        'validation_data': eval_dataset,
        'steps_per_epoch': fn_args.train_steps,
        'validation_steps': fn_args.eval_steps,
    }
    return TunerFnResult(tuner=tuner, fit_kwargs=fit_kwargs)
def run_fn(fn_args):
    """Train the Keras model and save it with a ``serving_raw`` signature.

    The function loads the transform output and schema, builds the training
    and evaluation datasets, builds the model under a ``MirroredStrategy``
    scope, fits it with a TensorBoard callback, and saves the model to
    ``fn_args.serving_model_dir`` in the ``tf`` format.

    Args:
        fn_args: Object providing ``transform_output``, ``train_files``,
            ``eval_files``, ``serving_model_dir``, ``train_steps``,
            ``eval_steps``, ``schema_file`` and ``hyperparameters``.
    """
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_files = fn_args.train_files
    eval_files = fn_args.eval_files
    serving_model_dir = fn_args.serving_model_dir
    train_steps = fn_args.train_steps
    eval_steps = fn_args.eval_steps

    schema = tfdv.load_schema_text(fn_args.schema_file)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec

    # Hyperparameters may arrive wrapped in a dict under a 'values' key
    # (presumably a serialized tuner config -- confirm against the caller).
    # isinstance() also accepts dict subclasses, unlike an exact type() check.
    hparams = fn_args.hyperparameters
    if isinstance(hparams, dict) and 'values' in hparams:
        hparams = hparams['values']

    train_dataset = _input_fn(train_files, tf_transform_output, feature_spec)
    eval_dataset = _input_fn(eval_files, tf_transform_output, feature_spec)

    # NOTE(review): the scope extent was reconstructed after indentation was
    # lost; only model construction is assumed to sit inside it -- confirm.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = build_keras_model(hparams=hparams)

    # Logs go to a 'logs' directory next to the serving model directory.
    log_dir = os.path.join(os.path.dirname(serving_model_dir), 'logs')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='epoch')

    model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        validation_data=eval_dataset,
        validation_steps=eval_steps,
        callbacks=[tensorboard_callback])

    serving_raw_entry = get_serve_raw_fn(
        model, tf_transform_output, INPUT_WINDOW_SIZE)

    # One float32 spec of shape [None, INPUT_WINDOW_SIZE] per input feature.
    serving_raw_signature_tensorspecs = {
        x: tf.TensorSpec(
            shape=[None, INPUT_WINDOW_SIZE], dtype=tf.float32, name=x)
        for x in INPUT_FEATURE_KEYS
    }
    logging.info(
        f'serving_raw signature TensorSpecs are: {serving_raw_signature_tensorspecs}')

    signatures = {
        'serving_raw': serving_raw_entry.get_concrete_function(
            serving_raw_signature_tensorspecs),
    }

    model.save(serving_model_dir, save_format='tf',
               signatures=signatures)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment