Created
September 25, 2020 11:36
-
-
Save ntakouris/5148e5eadd99cbb660224268c49d1f93 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_input_graph(input_feature_keys) -> Tuple[Dict[Text, tf.Tensor], tf.Tensor]:
    """Create named scalar inputs and merge them into one plain feature tensor.

    One ``Input`` of shape ``(1,)`` and dtype ``float32`` is created per
    feature key, named ``transformed_name(key)`` (the column naming used when
    features come out of TensorFlow Transform / TFX). The inputs are then
    concatenated along the last axis, so the rest of the model can work on a
    single tensor instead of a dict of columns.

    Args:
        input_feature_keys: Iterable of raw feature keys. Each one is passed
            through ``transformed_name`` to build the input's name.

    Returns:
        Tuple[Dict[Text, tf.Tensor], tf.Tensor]: The dict of named model
        inputs (keyed by transformed column name, suitable as the model's
        ``inputs``) and the merged tensor of shape
        ``[None (batch size), num_features]``. Here ``num_features`` is the
        number of distinct transformed column names.

    Raises:
        ValueError: If ``input_feature_keys`` is empty.
    """
    # If you are using TensorFlow Transform or TensorFlow Extended, the
    # columns arrive under their transformed names.
    transformed_columns = [transformed_name(key) for key in input_feature_keys]
    if not transformed_columns:
        raise ValueError("input_feature_keys must contain at least one key")

    # Keyed by column name so callers can feed features by name.
    # Duplicate transformed names collapse to a single Input here.
    input_layers = {
        colname: Input(name=colname, shape=(1,), dtype=tf.float32)
        for colname in transformed_columns
    }
    inputs = list(input_layers.values())

    # Each input is (None, 1); concatenating along the last axis yields
    # (None, num_features). Some tf.keras versions' merge layers require at
    # least two inputs, so a lone feature is used directly.
    if len(inputs) == 1:
        pre_model_input = inputs[0]
    else:
        pre_model_input = Concatenate(axis=-1)(inputs)

    # Size the reshape by the inputs actually created rather than by
    # len(input_feature_keys), so duplicate keys cannot cause a size mismatch.
    pre_model_input = Reshape(target_shape=(len(input_layers),))(pre_model_input)
    return input_layers, pre_model_input
def get_output_graph(head_layer: tf.Tensor, predict_feature_keys) -> Dict[Text, tf.Tensor]:
    """Turn a plain feature tensor into one named scalar output per key.

    For every key, a separate ``Dense(units=1)`` layer named after that key
    is applied to ``head_layer``. The resulting dict can be used as the named
    outputs of a functional Keras model.

    Args:
        head_layer: Output tensor of the model's final feature layer. It is
            fed to each of the per-key ``Dense`` heads.
        predict_feature_keys: Iterable of output names. Keys are used
            verbatim as layer names and dict keys (``transformed_name`` is
            not applied here).

    Returns:
        Dict[Text, tf.Tensor]: For each key, the output tensor of that key's
        ``Dense(units=1)`` head.
    """
    return {
        colname: Dense(units=1, name=colname)(head_layer)
        for colname in predict_feature_keys
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment