# Gist by @ntakouris, created November 19, 2020.
from functools import partial
from typing import Any, Dict, List, Text
from absl import logging
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_transform as tft
import tensorflow_data_validation as tfdv


def transformed_name(key: Text) -> Text:
    return key + '_xf'


def _preprocessing_fn(inputs: Dict[Text, Any], dense_float_feature_keys,
                      input_feature_keys) -> Dict[Text, Any]:
    """
    Performs feature selection and preprocessing based on each feature key set.
    """
    outputs = {}
    for key in [k for k in dense_float_feature_keys if k in input_feature_keys]:
        outputs[transformed_name(key)] = tft.scale_to_z_score(inputs[key])
    return outputs
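

# Usage sketch, not part of the gist: a TFX Transform component expects a
# module-level `preprocessing_fn(inputs)`, and one way to bind the key sets
# is functools.partial. The key lists below are purely illustrative.
_DENSE_FLOAT_FEATURE_KEYS = ['open', 'high', 'low', 'close']  # hypothetical
_INPUT_FEATURE_KEYS = ['open', 'close']  # hypothetical
preprocessing_fn = partial(_preprocessing_fn,
                           dense_float_feature_keys=_DENSE_FLOAT_FEATURE_KEYS,
                           input_feature_keys=_INPUT_FEATURE_KEYS)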


def get_apply_tft_layer(tf_transform_output, window_size):
    """
    Adapts the TFT layer, which is designed for [None, 1] input shapes,
    to [None, window_size] shapes, by flattening the inputs and reshaping
    the outputs back to windowed form.
    """
    tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def apply_tf_transform(raw_features_dict):
        # Unbatch raw_features_dict by flattening each [None, window_size]
        # tensor into the [None * window_size, 1] shape the TFT layer expects.
        unbatched_raw_features = {
            k: K.expand_dims(K.flatten(v))
            for k, v in raw_features_dict.items()
        }
        transformed_features = tft_layer(unbatched_raw_features)
        # Reshape each transformed feature back to (None, window_size).
        windowed_features = {
            k: K.reshape(v, [-1, window_size])
            for k, v in transformed_features.items()
        }
        return windowed_features

    return apply_tf_transform
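

# Usage sketch; the transform output path, feature name, and window size are
# illustrative. `tft.TFTransformOutput` is the standard way to load a
# Transform artifact:
#
#   tf_transform_output = tft.TFTransformOutput('/path/to/transform_output')
#   apply_fn = get_apply_tft_layer(tf_transform_output, window_size=30)
#   transformed = apply_fn({'close': tf.zeros([8, 30])})
#   # -> {'close_xf': <float32 tensor of shape [8, 30]>}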


def get_serve_raw_fn(model, tf_transform_output, window_size):
    """
    Returns a tf.function that performs preprocessing and inference.
    For use as the serving_raw signature.
    """
    logging.info('_get_serve_raw_fn')
    model.preprocessing_layer = get_apply_tft_layer(
        tf_transform_output, window_size)

    @tf.function
    def serve_raw_fn(features):
        preprocessed_features = model.preprocessing_layer(features)
        return model(preprocessed_features)

    return serve_raw_fn
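

# Usage sketch; the feature name, shapes, and path below are illustrative.
# The raw-serving signature would typically be attached at export time:
#
#   serve_raw_fn = get_serve_raw_fn(model, tf_transform_output, window_size=30)
#   signatures = {
#       'serving_raw': serve_raw_fn.get_concrete_function(
#           {'close': tf.TensorSpec([None, 30], tf.float32, name='close')}),
#   }
#   model.save('/path/to/serving_model', save_format='tf',
#              signatures=signatures)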


def get_apply_tft_map_fn(tf_transform_output, window_size):
    """
    Returns a tf.function that applies the TFT layer for a given window size,
    while passing y through unchanged. For use with tf.data.Dataset (training).
    """
    apply_tft_layer_fn = get_apply_tft_layer(tf_transform_output, window_size)

    @tf.function
    def apply_tft_y_passthrough(raw_features_dict, y):
        return apply_tft_layer_fn(raw_features_dict), y

    return apply_tft_y_passthrough


@tf.function
def sub_to_batch(sub, window):
    # Turns a window sub-dataset into a single batched element of size `window`.
    return sub.batch(window)


@tf.function
def unmarshal(x, input_feature_keys):
    """
    'Flattens' x by extracting input_feature_keys and stacking them into a
    single tensor, so that windowing operations can be performed.
    """
    t = [x[k] for k in input_feature_keys]
    return tf.reshape(t, [len(input_feature_keys)])
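

# Worked example with illustrative values: given input_feature_keys=['a', 'b'],
#
#   row = unmarshal({'a': tf.constant(1.0), 'b': tf.constant(2.0)}, ['a', 'b'])
#   # -> tf.Tensor([1., 2.], shape=(2,), dtype=float32)
#
# i.e. one timestep row that the .window() call in input_fn can slide over.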


@tf.function
def marshal(x, y, input_feature_keys, predict_feature_keys, feature_keys):
    """
    Maps each column of the stacked tensors back to a named feature, using
    the key's index in feature_keys.
    """
    features_x = {
        k: x[:, feature_keys.index(k)] for k in input_feature_keys
    }
    # The targets are sliced from y, the output window produced by collate_fn.
    features_y = {
        k: y[:, feature_keys.index(k)] for k in predict_feature_keys
    }
    return features_x, features_y


@tf.function
def collate_fn(x, x_idx_start, x_idx_num, y_idx_start, y_idx_num, input_feature_keys):
    """
    Splits the window tensor x into x[x_idx_start : x_idx_start + x_idx_num]
    and x[y_idx_start : y_idx_start + y_idx_num] via tf ops.
    """
    x_out = tf.slice(x, [x_idx_start, 0], [x_idx_num, len(input_feature_keys)])
    y_out = tf.slice(x, [y_idx_start, 0], [y_idx_num, len(input_feature_keys)])
    return x_out, y_out
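

# Worked example with illustrative values: with input_window_size=2,
# output_window_size=1 and feature_keys = input_feature_keys =
# predict_feature_keys = ['a'], a big window [[1.], [2.], [3.]] is split by
# collate_fn into x = [[1.], [2.]] and y = [[3.]]; marshal then yields
# ({'a': [1., 2.]}, {'a': [3.]}).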


def gzip_reader_fn(filenames):
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')


def input_fn(file_pattern, tf_transform_output,
             input_window_size, output_window_size,
             feature_spec,
             feature_keys, input_feature_keys, predict_feature_keys,
             batch_size=256):
    """
    Performs all the required preprocessing for model training and/or
    hyperparameter search.
    First, the data is unmarshalled into stacked tensors so that the windowing
    and batching operations keep it in correct sequence order.
    Then, marshalling is performed again so that named inputs are available
    for the Keras model.
    Finally, batching and an in-memory shuffle are performed, along with the
    preprocessing layer.
    Returns:
        A `tf.data.Dataset`
    """
    big_window = input_window_size + output_window_size

    collate_fn_partial = partial(collate_fn, x_idx_start=0, x_idx_num=input_window_size,
                                 y_idx_start=input_window_size, y_idx_num=output_window_size,
                                 input_feature_keys=input_feature_keys)
    unmarshal_fn_partial = partial(unmarshal,
                                   input_feature_keys=input_feature_keys)
    marshal_fn_partial = partial(marshal,
                                 input_feature_keys=input_feature_keys,
                                 predict_feature_keys=predict_feature_keys,
                                 feature_keys=feature_keys)
    apply_tf_transform_map_fn = get_apply_tft_map_fn(
        tf_transform_output, input_window_size)

    dataset = (
        tf.data.experimental.make_batched_features_dataset(
            file_pattern=file_pattern,
            features=feature_spec,
            reader=gzip_reader_fn,
            shuffle=False,
            sloppy_ordering=False,
            # Read large, ordered batches; shuffling happens after windowing.
            batch_size=big_window * batch_size * 8)
        # One element per timestep, as a stacked [num_features] tensor.
        .unbatch()
        .map(unmarshal_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # Slide a window of input + output timesteps over the sequence;
        # drop_remainder avoids short trailing windows collate_fn cannot slice.
        .window(big_window, shift=1, drop_remainder=True)
        .flat_map(partial(sub_to_batch, window=big_window))
        # Split each window into (input_window, output_window) tensors.
        .map(collate_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .shuffle(big_window * batch_size * 8, reshuffle_each_iteration=True)
        # Map columns back to named features for the Keras model.
        .map(marshal_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .batch(batch_size)
        # Apply the TFT preprocessing layer, passing y through.
        .map(apply_tf_transform_map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .prefetch(tf.data.experimental.AUTOTUNE))

    return dataset
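

# Usage sketch; every path and key list below is illustrative, not taken
# from the gist:
#
#   tf_transform_output = tft.TFTransformOutput('/path/to/transform_output')
#   train_dataset = input_fn(
#       file_pattern='/path/to/train-*.gz',
#       tf_transform_output=tf_transform_output,
#       input_window_size=30, output_window_size=5,
#       feature_spec=tf_transform_output.raw_feature_spec(),
#       feature_keys=['open', 'close'],
#       input_feature_keys=['open', 'close'],
#       predict_feature_keys=['close'],
#       batch_size=256)
#   model.fit(train_dataset, epochs=10)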