Skip to content

Instantly share code, notes, and snippets.

@thejevans
Created May 5, 2018 21:53
Show Gist options
  • Save thejevans/ff9ce77098b3848d0c263dda51635800 to your computer and use it in GitHub Desktop.
Save thejevans/ff9ce77098b3848d0c263dda51635800 to your computer and use it in GitHub Desktop.
# Number of times next_batch() has wrapped back to the start of a dataset.
# NOTE: shared by every dataset passed to next_batch(), not per-dataset.
epoch = 0


def next_batch(images, labels, start, batch_size):
    """Return the next contiguous mini-batch of ``images`` and ``labels``.

    The batch is the slice ``[start, start + batch_size)``.  If that window
    would run past the end of ``images``, the remaining tail examples are
    skipped, reading restarts at index 0, and the module-level ``epoch``
    counter is incremented.

    Args:
        images: Sized, sliceable sequence of inputs (presumably a NumPy
            array -- not visible from here, confirm with callers).
        labels: Sliceable sequence aligned index-for-index with ``images``.
        start: Index of the first example of the batch.
        batch_size: Number of examples per batch; must satisfy
            ``1 <= batch_size <= len(images)``.

    Returns:
        ``(end, image_batch, label_batch)`` where ``end`` is the index to
        pass back as ``start`` on the next call.

    Raises:
        ValueError: If ``batch_size`` is not in ``[1, len(images)]``.
    """
    global epoch
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let an oversized batch silently come back
    # short, and a non-positive size would yield empty batches forever.
    if not 1 <= batch_size <= len(images):
        raise ValueError(
            f"batch_size must be between 1 and {len(images)}, got {batch_size}"
        )
    end = start + batch_size
    if end > len(images):
        # Not enough examples left for a full batch: drop the tail and
        # begin a new pass over the data.
        epoch += 1
        start = 0
        end = batch_size
    return end, images[start:end], labels[start:end]
# Cumulative number of optimizer steps run across all calls to train().
total_iterations = 0


def train(num_iteration):
    """Run ``num_iteration`` optimizer steps, reporting progress as it goes.

    Each step runs ``optimizer`` through ``session`` on one training batch
    and also draws one validation batch from the test set.  On the first
    step of each epoch (every ``len(train_data) // batch_size`` global
    steps) the validation loss is computed and, when ``len(sys.argv) > 2``,
    handed to ``show_progress``.  On the final step of this call the
    accuracy on that step's validation batch is printed.

    Relies on module-level names defined elsewhere in the file:
    ``session`` (presumably a TensorFlow 1.x Session -- confirm),
    ``optimizer``, ``cost``, ``accuracy``, the feed keys ``x``/``y_true``,
    ``train_data``/``train_labels``, ``test_data``/``test_labels``,
    ``batch_size``, ``show_progress`` and ``sys``.

    Args:
        num_iteration: Number of optimizer steps to run in this call.

    Raises:
        ValueError: If there are steps to run and ``batch_size`` is not in
            ``[1, len(train_data)]``, i.e. no full batch fits in an epoch.
    """
    global total_iterations
    # NOTE(review): both cursors restart at 0 on every call, so repeated
    # calls re-read the beginning of each dataset -- confirm intended.
    train_start = 0
    test_start = 0
    steps = range(total_iterations, total_iterations + num_iteration)
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # epoch arithmetic after some optimizer steps have already run.
    if steps and not 1 <= batch_size <= len(train_data):
        raise ValueError(
            f"batch_size ({batch_size}) must be between 1 and the training "
            f"set size ({len(train_data)})"
        )
    # Loop-invariant; integer division avoids float rounding of int(a / b).
    batches_per_epoch = len(train_data) // batch_size if steps else 0
    for i in steps:
        train_start, x_batch, y_true_batch = next_batch(
            train_data, train_labels, train_start, batch_size)
        # NOTE(review): when the test cursor wraps, next_batch() also bumps
        # the module-level ``epoch`` counter, so that counter mixes train and
        # test wraps; the epoch reported below is computed independently.
        test_start, x_valid_batch, y_valid_batch = next_batch(
            test_data, test_labels, test_start, batch_size)
        feed_dict_tr = {x: x_batch,
                        y_true: y_true_batch}
        feed_dict_val = {x: x_valid_batch,
                         y_true: y_valid_batch}
        session.run(optimizer, feed_dict=feed_dict_tr)
        if i % batches_per_epoch == 0:
            val_loss = session.run(cost, feed_dict=feed_dict_val)
            current_epoch = i // batches_per_epoch
            if len(sys.argv) > 2:
                show_progress(current_epoch, feed_dict_tr, feed_dict_val,
                              val_loss)
        if i == steps[-1]:
            print(session.run(accuracy, feed_dict=feed_dict_val))
    total_iterations += num_iteration
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment