Created
May 5, 2018 21:53
-
-
Save thejevans/ff9ce77098b3848d0c263dda51635800 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Number of times next_batch has wrapped a data stream back to index 0.
# NOTE: this single counter is shared by every dataset passed to next_batch
# (in this file both the training and the test stream advance it), so it is
# not a per-dataset epoch count.
epoch = 0


def next_batch(images, labels, start, batch_size):
    """Return the next ``batch_size`` slice of ``images``/``labels`` from ``start``.

    When the requested slice would run past the end of the data, the trailing
    partial batch is skipped.  Instead, the cursor wraps to 0, the module-level
    ``epoch`` counter is incremented, and the first ``batch_size`` items are
    returned.

    Args:
        images: Sequence of inputs supporting ``len`` and slicing.
        labels: Sequence of targets, sliced with the same indices as ``images``.
        start: Index of the first item of the requested batch.
        batch_size: Number of items per batch.

    Returns:
        A tuple ``(end, image_batch, label_batch)``.  ``end`` is the ``start``
        value to pass on the next call.

    Raises:
        ValueError: If ``batch_size`` exceeds ``len(images)``; such a batch can
            never be filled.
    """
    global epoch
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would make every call wrap and return a short batch.
    if batch_size > len(images):
        raise ValueError(
            "batch_size ({}) must not exceed the number of samples ({})".format(
                batch_size, len(images)))
    end = start + batch_size
    if end > len(images):
        # Not enough items left for a full batch: drop the remainder and
        # restart from the beginning of the data.
        epoch += 1
        start = 0
        end = batch_size
    return end, images[start:end], labels[start:end]
# Iterations completed across all calls to train(); each call continues the
# iteration numbering from this value.
total_iterations = 0


def train(num_iteration):
    """Run ``num_iteration`` optimizer steps, continuing the global iteration count.

    The function uses module-level names defined elsewhere in the file:
    ``train_data``, ``train_labels``, ``test_data``, ``test_labels``,
    ``batch_size``, ``x`` / ``y_true`` (the feed_dict keys), ``session``,
    the ``optimizer`` / ``cost`` / ``accuracy`` fetches passed to
    ``session.run``, ``show_progress`` and ``sys``.

    Each iteration runs ``optimizer`` once on a training batch.  Every
    ``len(train_data) // batch_size`` iterations (including the first), it
    evaluates ``cost`` on the current test batch.  That loss is passed to
    ``show_progress`` when ``len(sys.argv) > 2``.  After the last iteration
    of this call, ``accuracy`` on the final test batch is printed.

    Args:
        num_iteration: Number of training steps to run in this call.
    """
    global total_iterations
    # NOTE(review): the batch cursors restart at 0 on every call, while
    # total_iterations keeps counting.  A second call therefore re-reads the
    # data from the beginning, and the epoch boundaries below no longer match
    # the actual data position.  Confirm whether resumable training is intended.
    train_start = 0
    test_start = 0
    # Loop-invariant: number of full training batches per epoch.  Integer
    # division matches next_batch, which drops any trailing partial batch.
    iters_per_epoch = len(train_data) // batch_size
    first = total_iterations
    last = first + num_iteration - 1
    for i in range(first, last + 1):
        train_start, x_batch, y_true_batch = next_batch(
            train_data, train_labels, train_start, batch_size)
        # A test batch is drawn on every step, even though it is only used at
        # epoch boundaries and on the last step.  Skipping it would change
        # which test batch gets evaluated.
        test_start, x_valid_batch, y_valid_batch = next_batch(
            test_data, test_labels, test_start, batch_size)
        feed_dict_tr = {x: x_batch, y_true: y_true_batch}
        feed_dict_val = {x: x_valid_batch, y_true: y_valid_batch}
        session.run(optimizer, feed_dict=feed_dict_tr)
        if i % iters_per_epoch == 0:
            val_loss = session.run(cost, feed_dict=feed_dict_val)
            # Named epoch_num so it does not shadow the module-level ``epoch``
            # counter that next_batch maintains.
            epoch_num = i // iters_per_epoch
            if len(sys.argv) > 2:
                show_progress(epoch_num, feed_dict_tr, feed_dict_val, val_loss)
        if i == last:
            print(session.run(accuracy, feed_dict=feed_dict_val))
    total_iterations += num_iteration
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment