import os
from functools import partial
from typing import Any, Dict, List, Text

import kerastuner
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_transform as tft
import tensorflow_data_validation as tfdv
from absl import logging
from tensorflow_transform.tf_metadata import schema_utils
from tfx.components.trainer.fn_args_utils import FnArgs
from tfx.components.tuner.component import TunerFnResult

from rnn.constants import (BATCH_SIZE, DENSE_FLOAT_FEATURE_KEYS, FEATURE_KEYS,
                           PREDICT_FEATURE_KEYS, INPUT_FEATURE_KEYS,
                           INPUT_WINDOW_SIZE, OUTPUT_WINDOW_SIZE,
                           HYPERPARAMETERS)
from rnn.model import build_keras_model
from input_fn_utils import input_fn, get_serve_raw_fn, _preprocessing_fn


def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
    # TFX Transform callback: delegates to the shared preprocessing helper.
    return _preprocessing_fn(inputs,
                             dense_float_feature_keys=DENSE_FLOAT_FEATURE_KEYS,
                             input_feature_keys=INPUT_FEATURE_KEYS)


def _input_fn(train_files, tf_transform_output, feature_spec):
    # Wraps the shared input_fn with the project's window and batch constants.
    return input_fn(train_files, tf_transform_output,
                    feature_spec=feature_spec,
                    input_window_size=INPUT_WINDOW_SIZE,
                    output_window_size=OUTPUT_WINDOW_SIZE,
                    batch_size=BATCH_SIZE,
                    predict_feature_keys=PREDICT_FEATURE_KEYS,
                    feature_keys=FEATURE_KEYS,
                    input_feature_keys=INPUT_FEATURE_KEYS)


def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    # TFX Tuner entry point.
    # ... (tuner and dataset construction elided in the original gist)
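    # A minimal sketch of the elided setup, as an assumption rather than the
    # original code: build the datasets the same way run_fn does and search
    # the predefined HYPERPARAMETERS space with a random-search tuner.
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    schema = tfdv.load_schema_text(fn_args.schema_file)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec
    train_dataset = _input_fn(fn_args.train_files, tf_transform_output, feature_spec)
    eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, feature_spec)
    tuner = kerastuner.RandomSearch(
        lambda hp: build_keras_model(hparams=hp),
        hyperparameters=HYPERPARAMETERS,
        objective=kerastuner.Objective('val_loss', 'min'),
        max_trials=10,  # placeholder trial budget
        directory=fn_args.working_dir,
        project_name='rnn_tuning')  # placeholder project name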
    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            'x': train_dataset,
            'validation_data': eval_dataset,
            'steps_per_epoch': fn_args.train_steps,
            'validation_steps': fn_args.eval_steps
        })


def run_fn(fn_args: FnArgs):
    # TFX Trainer entry point: train the model and export it for serving.
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

    train_files = fn_args.train_files
    eval_files = fn_args.eval_files
    serving_model_dir = fn_args.serving_model_dir
    train_steps = fn_args.train_steps
    eval_steps = fn_args.eval_steps

    # Resolve the feature spec from the schema produced upstream.
    schema = tfdv.load_schema_text(fn_args.schema_file)
    feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec

    # Hyperparameters may arrive as a serialized kerastuner config dict with a
    # 'values' entry; unwrap it if so.
    hparams = fn_args.hyperparameters
    if isinstance(hparams, dict) and 'values' in hparams:
        hparams = hparams['values']

    train_dataset = _input_fn(train_files, tf_transform_output, feature_spec)
    eval_dataset = _input_fn(eval_files, tf_transform_output, feature_spec)

    # Build the model inside a MirroredStrategy scope so training can use all
    # available GPUs on this worker.
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = build_keras_model(hparams=hparams)

    log_dir = os.path.join(os.path.dirname(serving_model_dir), 'logs')
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, update_freq='epoch')

    model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        validation_data=eval_dataset,
        validation_steps=eval_steps,
        callbacks=[tensorboard_callback])

    # Export a 'serving_raw' signature that accepts raw (untransformed) float
    # windows of shape [batch, INPUT_WINDOW_SIZE], one per input feature.
    serving_raw_entry = get_serve_raw_fn(
        model, tf_transform_output, INPUT_WINDOW_SIZE)
    serving_raw_signature_tensorspecs = {
        x: tf.TensorSpec(shape=[None, INPUT_WINDOW_SIZE], dtype=tf.float32, name=x)
        for x in INPUT_FEATURE_KEYS
    }
    logging.info(
        f'serving_raw signature TensorSpecs are: {serving_raw_signature_tensorspecs}')

    signatures = {
        'serving_raw': serving_raw_entry.get_concrete_function(
            serving_raw_signature_tensorspecs),
    }

    model.save(serving_model_dir, save_format='tf', signatures=signatures)
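
# A minimal usage sketch (an assumption, not part of the original gist):
# loading the exported SavedModel and calling the 'serving_raw' signature
# with placeholder inputs. 'export_dir' is a hypothetical path.
#
#   loaded = tf.saved_model.load(export_dir)
#   serving_raw = loaded.signatures['serving_raw']
#   dummy_inputs = {key: tf.zeros([1, INPUT_WINDOW_SIZE], dtype=tf.float32)
#                   for key in INPUT_FEATURE_KEYS}
#   predictions = serving_raw(**dummy_inputs)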