Skip to content

Instantly share code, notes, and snippets.

@e-lin
Last active July 9, 2019 12:18
Show Gist options
  • Save e-lin/86b79eac2506dc64d766c42e7ba661e3 to your computer and use it in GitHub Desktop.
Save e-lin/86b79eac2506dc64d766c42e7ba661e3 to your computer and use it in GitHub Desktop.
Autoencoder: training one autoencoder at a time, each in its own graph
def train_autoencoder(X_train, n_neurons, n_epochs, batch_size,
                      learning_rate=0.01, l2_reg=0.0005, seed=42,
                      hidden_activation=tf.nn.elu,
                      output_activation=tf.nn.elu):
    """Train a single-hidden-layer autoencoder in its own TensorFlow graph.

    A fresh ``tf.Graph`` is built on every call, so several autoencoders can
    be trained one after another without sharing variables. The network is
    trained with Adam to minimise the mean squared reconstruction error of
    ``X_train`` plus the L2 kernel penalties collected in the graph.

    Args:
        X_train: 2-D training data. It must expose ``.shape[1]`` (the number
            of input features) and support integer-array indexing
            (presumably a NumPy array — confirm with callers).
        n_neurons: Number of units in the hidden (coding) layer.
        n_epochs: Number of passes over ``X_train``.
        batch_size: Mini-batch size; must be >= 1. If it exceeds the number
            of rows, each epoch uses a single batch containing every row.
        learning_rate: Adam learning rate.
        l2_reg: Scale of the L2 kernel regularizer applied to both layers.
        seed: Graph-level TensorFlow random seed (affects the graph's random
            ops such as weight initialization). Batch shuffling uses the
            module-level ``rnd`` and is NOT controlled by this seed.
        hidden_activation: Activation function of the hidden layer.
        output_activation: Activation function of the output layer.

    Returns:
        Tuple ``(hidden_val, W_hidden, b_hidden, W_outputs, b_outputs)``:
        the hidden-layer activations for all of ``X_train``, followed by the
        trained kernel and bias of the hidden layer and of the output layer.

    Raises:
        ValueError: If ``X_train`` has no rows or ``batch_size`` < 1.
    """
    n_samples = len(X_train)
    if n_samples == 0:
        raise ValueError("X_train must contain at least one row")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    # len(X_train) // batch_size is 0 when batch_size > n_samples; the batch
    # loop would then never run and X_batch would be unbound below.
    n_batches = max(1, n_samples // batch_size)

    graph = tf.Graph()
    with graph.as_default():
        tf.set_random_seed(seed)
        n_inputs = X_train.shape[1]
        X = tf.placeholder(tf.float32, shape=[None, n_inputs])
        my_dense_layer = partial(
            tf.layers.dense,
            kernel_initializer=tf.contrib.layers.variance_scaling_initializer(),
            kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_reg))
        hidden = my_dense_layer(X, n_neurons, activation=hidden_activation,
                                name="hidden")
        outputs = my_dense_layer(hidden, n_inputs,
                                 activation=output_activation, name="outputs")
        reconstruction_loss = tf.reduce_mean(tf.square(outputs - X))  # MSE
        reg_losses = graph.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        loss = tf.add_n([reconstruction_loss] + reg_losses)
        optimizer = tf.train.AdamOptimizer(learning_rate)
        training_op = optimizer.minimize(loss)
        trainable_vars = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        init = tf.global_variables_initializer()

    with tf.Session(graph=graph) as sess:
        init.run()
        for epoch in range(n_epochs):
            # Shuffle once per epoch (O(n)) rather than building a full
            # permutation for every batch; batches are disjoint slices, so
            # each row is visited at most once per epoch.
            shuffled = rnd.permutation(n_samples)
            for iteration in range(n_batches):
                print("\r{}%".format(100 * iteration // n_batches), end="")
                sys.stdout.flush()
                start = iteration * batch_size
                X_batch = X_train[shuffled[start:start + batch_size]]
                sess.run(training_op, feed_dict={X: X_batch})
            # Reported MSE is measured on the last mini-batch of the epoch.
            loss_train = reconstruction_loss.eval(feed_dict={X: X_batch})
            print("\r{}".format(epoch), "Train MSE:", loss_train)
        # Fetch all trainable variables in one run call instead of one each.
        values = sess.run(trainable_vars)
        params = {var.name: value for var, value in zip(trainable_vars, values)}
        hidden_val = hidden.eval(feed_dict={X: X_train})
    return (hidden_val,
            params["hidden/kernel:0"], params["hidden/bias:0"],
            params["outputs/kernel:0"], params["outputs/bias:0"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment