from functools import partial
from typing import Any, Dict, List, Text

from absl import logging

import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_transform as tft
import tensorflow_data_validation as tfdv


def transformed_name(key: Text) -> Text:
    return key + '_xf'


def _preprocessing_fn(inputs: Dict[Text, Any],
                      dense_float_feature_keys: List[Text],
                      input_feature_keys: List[Text]) -> Dict[Text, Any]:
    """
    Performs feature selection and preprocessing based on each feature key set.
    """
    outputs = {}
    for key in [k for k in dense_float_feature_keys if k in input_feature_keys]:
        outputs[transformed_name(key)] = tft.scale_to_z_score(inputs[key])
    return outputs


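# A minimal sketch of how _preprocessing_fn would typically be wired into a
# TFX Transform component: TFX expects a `preprocessing_fn(inputs)` callable,
# so the extra key-set arguments are bound with functools.partial. The key
# lists below are hypothetical placeholders, not part of the original gist.
DENSE_FLOAT_FEATURE_KEYS = ['open', 'high', 'low', 'close']  # assumption
INPUT_FEATURE_KEYS = ['open', 'close']  # assumption

preprocessing_fn = partial(_preprocessing_fn,
                           dense_float_feature_keys=DENSE_FLOAT_FEATURE_KEYS,
                           input_feature_keys=INPUT_FEATURE_KEYS)

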
def get_apply_tft_layer(tf_transform_output, window_size):
    """
    Applies the TFT layer, which is traced for [None, 1] input shapes,
    to [None, <window>] shapes, by flattening the inputs and restoring
    the window dimension on the outputs.
    """
    tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def apply_tf_transform(raw_features_dict):
        # unbatch raw_features_dict by flattening each [batch, window]
        # tensor into [batch * window, 1]
        unbatched_raw_features = {
            k: K.expand_dims(K.flatten(v))
            for k, v in raw_features_dict.items()
        }
        transformed_features = tft_layer(unbatched_raw_features)
        # restore the window dimension: reshape back to [batch, window]
        expanded_dims = {
            k: K.reshape(v, [-1, window_size])
            for k, v in transformed_features.items()
        }
        return expanded_dims

    return apply_tf_transform


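# Hedged usage sketch for get_apply_tft_layer; the transform output directory
# and feature name below are hypothetical. Each raw feature arrives as a
# [batch, window] tensor; the returned function flattens it to
# [batch * window, 1] so the TFT layer can run, then restores the window axis.
def _demo_apply_tft_layer():
    tf_transform_output = tft.TFTransformOutput('/path/to/transform_output')
    apply_fn = get_apply_tft_layer(tf_transform_output, window_size=30)
    raw = {'close': tf.random.uniform([8, 30])}  # [batch, window]
    return apply_fn(raw)  # each transformed value is again [batch, window]

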
def get_serve_raw_fn(model, tf_transform_output, window_size):
    """
    Returns a tf.function that performs preprocessing and inference.
    For usage as the serving_raw signature.
    """
    logging.info('_get_serve_raw_fn')
    model.preprocessing_layer = get_apply_tft_layer(
        tf_transform_output, window_size)

    @tf.function
    def serve_raw_fn(features):
        preprocessed_features = model.preprocessing_layer(features)
        return model(preprocessed_features)

    return serve_raw_fn


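# Sketch of exporting the raw serving signature; the feature name, dtype and
# save path are assumptions. The concrete function is traced on a dict of raw
# [None, window] tensors so clients can send un-preprocessed windows directly.
def _demo_export_serving_raw(model, tf_transform_output, window_size):
    serve_raw_fn = get_serve_raw_fn(model, tf_transform_output, window_size)
    signatures = {
        'serving_raw': serve_raw_fn.get_concrete_function({
            'close': tf.TensorSpec([None, window_size], tf.float32, name='close')
        })
    }
    model.save('/path/to/serving_model', save_format='tf', signatures=signatures)

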
def get_apply_tft_map_fn(tf_transform_output, window_size):
    """
    Returns a tf.function that applies the TFT layer given a window size,
    while passing y through untouched. For usage by tf.data.Dataset (training).
    """
    apply_tft_layer_fn = get_apply_tft_layer(tf_transform_output, window_size)

    @tf.function
    def apply_tft_y_passthrough(raw_features_dict, y):
        return apply_tft_layer_fn(raw_features_dict), y

    return apply_tft_y_passthrough


@tf.function
def sub_to_batch(sub, window):
    return sub.batch(window)


@tf.function
def unmarshal(x, input_feature_keys):
    """
    'Flattens' x by extracting input_feature_keys and stacking them as a
    single rank-1 tensor, so that windowing operations can be performed.
    """
    t = [x[k] for k in input_feature_keys]
    return tf.reshape(t, [len(input_feature_keys)])


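# Tiny runnable sketch of unmarshal: a dict of per-record features (as
# produced by .unbatch()) becomes one rank-1 tensor ordered by
# input_feature_keys, so windowing can operate on plain tensors.
# Feature names here are hypothetical.
def _demo_unmarshal():
    x = {'open': tf.constant([1.0]), 'close': tf.constant([2.0])}
    return unmarshal(x, input_feature_keys=['open', 'close'])  # -> [1.0, 2.0]

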
@tf.function
def marshal(x, y, input_feature_keys, predict_feature_keys, feature_keys):
    """
    Looks up the column index of each key in feature_keys (the column order
    produced by unmarshal) and maps the columns back to named feature dicts
    for the keras model.
    """
    features_x = {
        k: x[:, feature_keys.index(k)] for k in input_feature_keys
    }
    features_y = {
        k: y[:, feature_keys.index(k)] for k in predict_feature_keys
    }
    return features_x, features_y


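# Sketch of marshal on one collated window pair: x holds the input window,
# y the prediction window, both with one column per entry of feature_keys.
# Key lists and window sizes are hypothetical.
def _demo_marshal():
    x = tf.random.uniform([30, 2])  # [input_window, n_features]
    y = tf.random.uniform([5, 2])   # [output_window, n_features]
    return marshal(x, y,
                   input_feature_keys=['open', 'close'],
                   predict_feature_keys=['close'],
                   feature_keys=['open', 'close'])

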
@tf.function
def collate_fn(x, x_idx_start, x_idx_num, y_idx_start, y_idx_num, input_feature_keys):
    """
    Splits the window tensor x into x[x_idx_start:x_idx_start + x_idx_num]
    and x[y_idx_start:y_idx_start + y_idx_num] via tf ops.
    """
    x_out = tf.slice(x, [x_idx_start, 0], [x_idx_num, len(input_feature_keys)])
    y_out = tf.slice(x, [y_idx_start, 0], [y_idx_num, len(input_feature_keys)])
    return x_out, y_out


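# Sketch of collate_fn: a window of big_window rows is split into the first
# input_window rows (model inputs) and the following output_window rows
# (targets). The sizes below are hypothetical.
def _demo_collate():
    window = tf.random.uniform([35, 2])  # big_window = 30 + 5 rows, 2 features
    return collate_fn(window,
                      x_idx_start=0, x_idx_num=30,
                      y_idx_start=30, y_idx_num=5,
                      input_feature_keys=['open', 'close'])

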
def gzip_reader_fn(filenames):
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')


def input_fn(file_pattern, tf_transform_output,
             input_window_size, output_window_size,
             feature_spec,
             feature_keys, input_feature_keys, predict_feature_keys,
             batch_size=256):
    """
    Performs all the required preprocessing for model training and/or
    hyperparameter search.

    First, the data is unmarshalled so that windowing and batching operations
    keep it in correct sequence order. Then, marshalling is performed again
    so that named inputs are available for the keras model. Finally, batching
    and in-memory shuffling are performed, and the preprocessing layer is
    applied.

    Returns:
        A `tf.data.Dataset`
    """
    big_window = input_window_size + output_window_size

    collate_fn_partial = partial(collate_fn, x_idx_start=0, x_idx_num=input_window_size,
                                 y_idx_start=input_window_size, y_idx_num=output_window_size,
                                 input_feature_keys=input_feature_keys)

    unmarshal_fn_partial = partial(unmarshal,
                                   input_feature_keys=input_feature_keys)

    marshal_fn_partial = partial(marshal,
                                 input_feature_keys=input_feature_keys,
                                 predict_feature_keys=predict_feature_keys,
                                 feature_keys=feature_keys)

    apply_tf_transform_map_fn = get_apply_tft_map_fn(
        tf_transform_output, input_window_size)

    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern,
        features=feature_spec,
        reader=gzip_reader_fn,
        shuffle=False,
        sloppy_ordering=False,
        batch_size=big_window * batch_size * 8) \
        .unbatch() \
        .map(unmarshal_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
        .window(big_window, shift=1) \
        .flat_map(partial(sub_to_batch, window=big_window)) \
        .map(collate_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
        .shuffle(big_window * batch_size * 8, reshuffle_each_iteration=True) \
        .map(marshal_fn_partial, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
        .batch(batch_size) \
        .map(apply_tf_transform_map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
        .prefetch(tf.data.experimental.AUTOTUNE)

    return dataset


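# End-to-end usage sketch: all paths, key lists and window sizes below are
# assumptions, not part of the original gist. input_fn reads raw (GZIP
# TFRecord) examples, so the raw feature spec from the transform output is
# passed as feature_spec; the TFT layer is applied inside the pipeline.
def _demo_train(model):
    tf_transform_output = tft.TFTransformOutput('/path/to/transform_output')
    dataset = input_fn(
        file_pattern='/path/to/examples/*.gz',
        tf_transform_output=tf_transform_output,
        input_window_size=30, output_window_size=5,
        feature_spec=tf_transform_output.raw_feature_spec(),
        feature_keys=['open', 'close'],
        input_feature_keys=['open', 'close'],
        predict_feature_keys=['close'],
        batch_size=256)
    model.fit(dataset, epochs=10)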