Last active
June 5, 2019 21:33
-
-
Save amohant4/475dae25d40ddfe545b009595e93f603 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Rule-based learning-rate schedule (TensorFlow 1.x graph mode).
#
# The learning rate is a scalar placeholder, so a new value can be fed at every
# training step without rebuilding the optimizer.  The graph and the session
# are created ONCE, before the loop: building them inside the loop would add
# duplicate ops to the default graph on every step and would run each update
# in a fresh session that shares no variable state with the previous one.
#
# Names expected from the surrounding code:
#   tf, total_steps, get_loss_change, theta, alpha, data

N = 10  # window size: how many recent losses to keep (can be anything)
loss_over_last_N_iters = []  # Losses of the last N iterations, newest first
lr = 0.01  # initial learning rate (can be anything)

# --- Build the graph once -------------------------------------------------
learning_rate = tf.placeholder(tf.float32, shape=[])  # scalar, fed per step
# … (model definition goes here)
loss = ...  # TODO: replace with the model's loss tensor
# Create the optimizer with the placeholder as its learning rate.  It
# minimizes the same tensor that is fetched as `error` in the loop below.
train_step = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(loss)

# One session for the whole run.  It is deliberately left open so code after
# the loop can keep using it; call sess.close() when finished.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# --- Training loop --------------------------------------------------------
for global_step in range(total_steps):
    # Determine whether the loss is still changing or has hit a plateau.
    # The history is empty on the first step and holds fewer than N entries
    # for the first N steps, so get_loss_change must cope with short input.
    # NOTE(review): the rate is scaled when the change EXCEEDS theta; confirm
    # this matches what get_loss_change returns (a plateau test would
    # normally fire when the change is small).
    change_in_loss = get_loss_change(loss_over_last_N_iters)
    if change_in_loss > theta:
        lr = lr * alpha  # Change the learning rate (e.g. alpha = 0.1 -> lr/10)

    # Feed the rule-based learning rate for this step through the feed dict.
    error, _ = sess.run(
        [loss, train_step],
        feed_dict={learning_rate: lr, data: ...},  # TODO: feed the real batch
    )

    # Record the new loss at the front and trim the history to N entries.
    loss_over_last_N_iters.insert(0, error)
    del loss_over_last_N_iters[N:]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment