Created July 17, 2018 13:35
-
-
Save risenW/61d98077f625299383c359bc2bee4c15 to your computer and use it in GitHub Desktop.
Second part of the code for mini-batch training.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mini-batch gradient-descent training loop.
# NOTE(review): tf, np, sess, W, X_data, Y_pred, Y_target, x_vals, y_vals and
# batch_size are defined in the first part of this script (not shown here).

# Mean squared error between the model output and the targets.
loss = tf.reduce_mean(tf.square(Y_pred - Y_target))

# Declare the optimizer: plain gradient descent with learning rate 0.02
# (TensorFlow 1.x graph API; under TF 2.x this lives in tf.compat.v1).
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# Batch losses, recorded once every 5 training steps.
loss_batch = []
for i in range(100):
    # Draw batch_size random indices over the whole dataset. Using len(x_vals)
    # instead of a hard-coded 100 keeps this correct if the dataset size
    # changes. np.random.choice samples WITH replacement by default, so a
    # batch may contain repeated points.
    rand_index = np.random.choice(len(x_vals), size=batch_size)
    # Wrap in a list and transpose so each sample becomes its own row
    # (a (batch_size, 1) column, assuming x_vals / y_vals are 1-D arrays).
    x_batch = np.transpose([x_vals[rand_index]])
    y_batch = np.transpose([y_vals[rand_index]])
    sess.run(train_step, feed_dict={X_data: x_batch, Y_target: y_batch})
    # Every 5 steps, report the current W and the loss on the batch that
    # was just trained on (evaluated after the update).
    if (i + 1) % 5 == 0:
        print('Step #', str(i+1), 'W = ', str(sess.run(W)))
        temp_loss = sess.run(loss, feed_dict={X_data: x_batch, Y_target: y_batch})
        loss_batch.append(temp_loss)
        print('Loss = ', temp_loss)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.