Last active
February 19, 2018 14:04
-
-
Save talolard/6c45f439a7524ef6b894d6225f6bbb58 to your computer and use it in GitHub Desktop.
Example of using a dataset and iterators to alternate between training and validation data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__":
    # Build the iterators: one op that yields the next batch, plus one
    # initializer op each for the training data and the validation data.
    next_element, training_init_op, validation_init_op = prepare_dataset_iterators(batch_size=32)
    # Elided in the original example. Presumably the model `M`, the session
    # `sess`, and the `lr` / `keep_prob` values are created here -- confirm
    # against the full script.
    ...
    for epoch in range(1000):
        # Initialize the iterator to consume training data.
        sess.run(training_init_op)
        while True:
            # Keep running training steps as long as the iterator is not empty;
            # exhaustion is signalled by tf.errors.OutOfRangeError.
            try:
                _, summary, gs = sess.run(
                    [M.train, M.write_op, M.gs],
                    feed_dict={M.lr: lr, M.keep_prob: keep_prob},
                )
            except tf.errors.OutOfRangeError:
                # Do stuff at the end of a training epoch here.
                break

        # Initialize the iterator to provide validation data.
        sess.run(validation_init_op)
        # Store the loss from each batch so an average can be computed.
        val_losses = []
        while True:
            # As long as the iterator is not empty.
            try:
                # keep_prob is fed as 1 here (presumably to disable dropout
                # during evaluation -- confirm against the model definition).
                loss, summary, gs, _ = sess.run(
                    [M.total_loss, M.write_op, M.gs, M.increment_gs],
                    feed_dict={M.lr: lr, M.keep_prob: 1},
                )
            except tf.errors.OutOfRangeError:
                # Do stuff at the end of a validation run here.
                break
            val_losses.append(loss)

        # Average validation loss for this epoch; None when the validation
        # iterator produced no batches, to avoid dividing by zero.
        mean_val_loss = sum(val_losses) / len(val_losses) if val_losses else None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment