Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created September 19, 2020 08:36
Show Gist options
  • Save ntakouris/9443f27f2704472188b3567d3f7a5640 to your computer and use it in GitHub Desktop.
def get_input_graph(input_feature_keys, input_window_size) -> Tuple[Dict[Text, Input], tf.keras.layers.Layer]:
    """Build one Keras ``Input`` per feature and stack them into a windowed tensor.

    Each key in ``input_feature_keys`` is passed through ``transformed_name``
    and gets its own float32 ``Input`` of shape ``(input_window_size,)``
    (batch axis excluded). The per-feature inputs are then stacked along a new
    trailing axis, so the combined tensor has shape
    ``(batch, input_window_size, n_features)`` and ``[:, t, f]`` holds feature
    ``f`` at window step ``t``.

    Args:
        input_feature_keys: Raw feature names; each is mapped through
            ``transformed_name`` to produce the ``Input`` name.
        input_window_size: Number of steps in each feature's input window.

    Returns:
        A pair ``(input_layers, pre_model_input)``. ``input_layers`` maps each
        transformed column name to its ``Input``; ``pre_model_input`` is the
        stacked tensor described above.

    Raises:
        ValueError: If ``input_feature_keys`` is empty.
    """
    transformed_columns = [transformed_name(key) for key in input_feature_keys]
    input_layers = {
        # shape must be a tuple; a bare int in parentheses is not one.
        colname: Input(name=colname, shape=(input_window_size,), dtype=tf.float32)
        for colname in transformed_columns
    }
    if not input_layers:
        raise ValueError("input_feature_keys must contain at least one feature")

    # Give every feature its own channel axis: (batch, window) -> (batch, window, 1).
    # Concatenating the raw (batch, window) inputs and reshaping afterwards lays the
    # data out feature-major, so a (window, n_features) reshape would put values from
    # different features and time steps into the same row.
    per_feature = [
        Reshape(target_shape=(input_window_size, 1))(layer)
        for layer in input_layers.values()
    ]
    if len(per_feature) == 1:
        # A single feature is already in the final shape; some Keras versions
        # reject a Concatenate called on fewer than two inputs.
        pre_model_input = per_feature[0]
    else:
        pre_model_input = Concatenate(axis=-1)(per_feature)
    return input_layers, pre_model_input
def get_output_graph(head_layer, predict_feature_keys, output_window_size) -> Dict[Text, tf.keras.layers.Layer]:
    """Attach one ``Dense`` prediction head per target feature.

    Args:
        head_layer: Shared tensor that every output head is applied to.
        predict_feature_keys: Names of the features to predict; each name is
            used both as the output dict key and as the ``Dense`` layer name.
        output_window_size: Number of units in each ``Dense`` head.

    Returns:
        Dict mapping each name in ``predict_feature_keys`` to the output of its
        ``Dense`` head applied to ``head_layer``, in iteration order.
    """
    heads = {}
    for feature in predict_feature_keys:
        projection = Dense(units=output_window_size, name=feature)
        heads[feature] = projection(head_layer)
    return heads
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment