Training Loop
Gist kovasb/ced7061fd0c895c0782dd5fa28de1728, created June 16, 2018 14:26
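# Example training loop for fitting a line through 2 points
# A line is defined as a*x + b (a and b are stored in the variables A and B below)
# We want the machine to learn what a and b are.
# For the points (0, 0) and (1, 3) used here, the exact answer is a = 3, b = 0.
# The important thing to note is the overall structure of the numbered components.

# Setup: written for TensorFlow 1.x with eager execution enabled.
# (In TensorFlow 2.x eager is the default and tf.contrib is gone; use tf.Variable and tf.GradientTape instead of tfe.*)
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()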
# 1. Batch of training data
# 1A: inputs used to generate predictions
x_values = tf.convert_to_tensor([0.0, 1.0])
# 1B: desired outputs or 'labels'
y_values = tf.convert_to_tensor([0.0, 3.0])
# 2. Model
# 2A: Variables that will be trained
A = tfe.Variable(0.0)
B = tfe.Variable(0.0)
# 2B: Prediction function: combines the training inputs with the variables via ops to generate a prediction
def predict_y_values(x_values):
    # the definition of a line
    return A * x_values + B
# The Training Loop
# Start with a simple 'for' loop
steps = 200
for i in range(steps):
    # Training logic for each iteration:
    with tfe.GradientTape() as tape:  # the tape records the ops run in this block so gradients can be computed; details later
        # Use the model to generate a prediction
        predicted_y_values = predict_y_values(x_values)
        # 3. Loss computes the error between the prediction and the desired output (mean squared error)
        loss_value = tf.reduce_mean(tf.square(predicted_y_values - y_values))
    # 5. Metrics tell you how well your model is training
    if i % 20 == 0:
        print("Loss at step {:03d}: {:.3f}".format(i, loss_value.numpy()))
    # 4. Update step
    # 4A: The gradient tells us which direction to change the variables to reduce the loss
    gradient_A, gradient_B = tape.gradient(loss_value, [A, B])  # Stay tuned!
    # 4B: Nudge the variables by a small step (learning rate 0.1) opposite the gradient, which lowers the loss
    A.assign_sub(gradient_A * 0.1)
    B.assign_sub(gradient_B * 0.1)
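To make step 4A less of a black box, here is a minimal sketch of the same loop without TensorFlow, assuming plain Python floats. For the loss L(a, b) = mean((a*x + b - y)^2), the gradients worked out by hand are dL/da = mean(2*(a*x + b - y)*x) and dL/db = mean(2*(a*x + b - y)), which is what tape.gradient computes automatically. The function name `train_line` and its `steps`/`learning_rate` parameters are hypothetical, chosen to mirror the loop above (200 steps, step size 0.1).

```python
# Hypothetical plain-Python version of the training loop above, with the
# gradients of the mean-squared-error loss derived by hand instead of by a tape.

def train_line(x_values, y_values, steps=200, learning_rate=0.1):
    """Fit y = a*x + b by gradient descent on mean squared error."""
    a, b = 0.0, 0.0  # same starting values as A and B above
    n = len(x_values)
    for i in range(steps):
        # Residuals: prediction minus desired output, one per training point
        errors = [a * x + b - y for x, y in zip(x_values, y_values)]
        # 3. Loss: mean of the squared residuals
        loss = sum(e * e for e in errors) / n
        # 5. Metrics: report progress every 20 steps
        if i % 20 == 0:
            print("Loss at step {:03d}: {:.3f}".format(i, loss))
        # 4A. Hand-derived gradients of the loss with respect to a and b
        grad_a = sum(2 * e * x for e, x in zip(errors, x_values)) / n
        grad_b = sum(2 * e for e in errors) / n
        # 4B. Step opposite the gradient
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b


a, b = train_line([0.0, 1.0], [0.0, 3.0])
print("a = {:.3f}, b = {:.3f}".format(a, b))
```

Because the loss and update rule match the TensorFlow version, the printed losses should closely track it (TensorFlow works in float32 there, so trailing digits may differ). The first line printed is `Loss at step 000: 4.500`: with a = b = 0 the errors are 0 and -3, and their mean square is 9/2.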