@crhea93
Last active October 17, 2022 13:44
RIM Gaussians 3
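import tensorflow as tf

# RIM, n, and the X/Y/A/N train and validation arrays are assumed to be
# defined earlier; they are not part of this snippet.

# Define training hyperparameters and instantiate the RIM model
epochs = 100
batch_size = 16
model = RIM(rnn_units1=256, rnn_units2=256, conv_filters=8, kernel_size=2,
            input_size=n, dimensions=1, t_steps=10, learning_rate=0.005)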
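# Prepare the training dataset: slice the four aligned arrays into per-sample tuples,
# group them into fixed-size batches (drop_remainder=True discards the final partial
# batch), and prefetch up to 2 batches so input loading overlaps with training
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train, A_train, N_train))
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
train_dataset = train_dataset.prefetch(2)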
# Prepare the validation dataset the same way
val_dataset = tf.data.Dataset.from_tensor_slices((X_valid, Y_valid, A_valid, N_valid))
val_dataset = val_dataset.batch(batch_size, drop_remainder=True)
val_dataset = val_dataset.prefetch(2)
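The tf.data pipeline above can be mimicked in plain Python to see what each batch contains and how many batches an epoch yields. This is a minimal sketch, not TensorFlow itself: `batch_slices` and `steps_per_epoch` are hypothetical helpers, the toy lists stand in for X_train, Y_train, A_train and N_train, and the dataset sizes are made up. `prefetch(2)` is left out because it only overlaps data loading with training and does not change which elements are produced.

```python
from typing import List, Sequence, Tuple


def batch_slices(
    arrays: Sequence[Sequence], batch_size: int, drop_remainder: bool = True
) -> List[Tuple[list, ...]]:
    """Hypothetical pure-Python analogue of
    tf.data.Dataset.from_tensor_slices(tuple(arrays)).batch(batch_size, drop_remainder).

    Each batch is a tuple holding one list per input array, mirroring how
    tf.data batches tuple-structured elements component by component.
    """
    n_samples = len(arrays[0])
    # from_tensor_slices requires every component to share the leading dimension
    if any(len(a) != n_samples for a in arrays):
        raise ValueError("all arrays must have the same number of samples")
    batches = []
    for start in range(0, n_samples, batch_size):
        stop = min(start + batch_size, n_samples)
        if drop_remainder and stop - start < batch_size:
            break  # discard the final partial batch
        batches.append(tuple(list(a[start:stop]) for a in arrays))
    return batches


def steps_per_epoch(n_samples: int, batch_size: int, drop_remainder: bool = True) -> int:
    """Number of batches one pass over n_samples yields."""
    if drop_remainder:
        return n_samples // batch_size
    return -(-n_samples // batch_size)  # ceiling division keeps the partial batch


# Toy stand-ins for X_train, Y_train, A_train, N_train (hypothetical data)
X = list(range(10))
Y = [x * 10 for x in X]
A = [x + 0.5 for x in X]
N = [-x for x in X]

batches = batch_slices((X, Y, A, N), batch_size=4)
print(len(batches))    # 2: samples 8 and 9 form a partial batch and are dropped
print(batches[0][0])   # [0, 1, 2, 3]

# Batch counts for the snippet's settings, using hypothetical dataset sizes
epochs, batch_size = 100, 16
n_train, n_valid = 1000, 250
print(steps_per_epoch(n_train, batch_size), n_train % batch_size)  # 62 8
print(epochs * steps_per_epoch(n_train, batch_size))               # 6200
print(steps_per_epoch(n_valid, batch_size), n_valid % batch_size)  # 15 10
```

Because the pipeline as written never calls `shuffle`, the same trailing `len(X_train) % 16` training samples (and `len(X_valid) % 16` validation samples) would be skipped on every epoch.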