Created
July 14, 2017 08:01
-
-
Save hskang9/14f648601eb3a23604dd6e9c529935e0 to your computer and use it in GitHub Desktop.
TensorFlow Supervised Learning Boilerplate
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# initialize variables/model parameters
# define the training loop operations
def inference(X):
    """Compute the inference model over data X and return the result.

    Template stub: replace the body with the model's forward computation.

    Args:
        X: input data; presumably the ``X`` produced by ``inputs()`` — confirm.

    Raises:
        NotImplementedError: always, until the template is filled in.
    """
    raise NotImplementedError("inference() must be implemented")
def loss(X, Y):
    """Compute the loss over training data X and expected outputs Y.

    Template stub. The returned value is fetched with ``sess.run`` by the
    training script, so it should be a graph tensor.

    Args:
        X: training inputs, as returned by ``inputs()``.
        Y: expected outputs, as returned by ``inputs()``.

    Raises:
        NotImplementedError: always, until the template is filled in.
    """
    raise NotImplementedError("loss() must be implemented")
def inputs():
    """Read or generate input training data X and expected outputs Y.

    Template stub. The training script unpacks the result as ``X, Y``, so an
    implementation must return a pair. If the pipeline is queue-based, its
    queue runners are started by the script after this call.

    Raises:
        NotImplementedError: always, until the template is filled in.
    """
    raise NotImplementedError("inputs() must be implemented")
def train(total_loss):
    """Build the op that adjusts model parameters to reduce ``total_loss``.

    Template stub. The training script runs the returned value once per
    training step via ``sess.run``, so it should be a graph op.

    Args:
        total_loss: the loss returned by ``loss(X, Y)``.

    Raises:
        NotImplementedError: always, until the template is filled in.
    """
    raise NotImplementedError("train() must be implemented")
def evaluate(sess, X, Y):
    """Evaluate the trained model.

    Template stub, called once after the training loop finishes.

    Args:
        sess: the live session the model was trained in.
        X: inputs, as returned by ``inputs()``.
        Y: expected outputs, as returned by ``inputs()``.

    Raises:
        NotImplementedError: always, until the template is filled in.
    """
    raise NotImplementedError("evaluate() must be implemented")
# Launch the graph in a session and run the training boilerplate.
# NOTE(review): relies on `tf` (TensorFlow 1.x graph API) and `trange`
# (presumably tqdm.trange) being imported elsewhere — no import is visible here.
with tf.Session() as sess:
    # Build the whole graph first: inputs(), loss() and train() may create
    # variables (model parameters, optimizer slots). The initializer op only
    # covers variables that exist when it is created, so it must come after.
    X, Y = inputs()
    total_loss = loss(X, Y)
    train_op = train(total_loss)

    # Some input pipelines create local variables (e.g. epoch counters);
    # initializing them is a no-op when there are none.
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])

    # Start any queue runners registered while building the graph; the
    # coordinator is used to stop and join those threads at the end.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # actual training loop
        training_steps = 1000
        for step in trange(training_steps):
            if step % 10 == 0:  # every 10 training steps
                # For debugging/learning: watch the loss decrease. Fetching it in
                # the same run as the update avoids dequeuing an extra batch
                # from a queue-fed pipeline just for logging.
                _, loss_value = sess.run([train_op, total_loss])
                print("loss: {}".format(loss_value))
            else:
                sess.run(train_op)

        # Evaluate the model
        evaluate(sess, X, Y)
    finally:
        # Stop and join the input threads even if training/evaluation raised.
        coord.request_stop()
        coord.join(threads)
    # Leaving the `with` block closes the session; no explicit close needed.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment