Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created March 12, 2019 13:43
Show Gist options
  • Select an option

  • Save khanhnamle1994/1b9fe5a15b62d1e57318f1f95feb295d to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/1b9fe5a15b62d1e57318f1f95feb295d to your computer and use it in GitHub Desktop.
def optimize(num_iterations, X):
    """Run ``num_iterations`` optimizer steps and plot the accuracy history.

    Every iteration draws a mini-batch with ``next_batch``, runs the
    module-level ``optimizer`` on it through the module-level ``session``,
    and then records the accuracy on that batch and on the whole
    validation set.  When the loop finishes, the elapsed wall-clock time is
    printed and both accuracy histories are drawn with ``plt``.

    Args:
        num_iterations: Number of optimizer steps to run in this call.
        X: Reporting interval.  A progress line is printed whenever the
            global ``total_iterations`` counter is a multiple of ``X``; the
            last iteration of the call is always reported.  A falsy value
            (``0`` or ``None``) turns the periodic reports off instead of
            raising ``ZeroDivisionError``.

    Side effects:
        Increments the global ``total_iterations`` once per step, updates
        the model through ``session``, prints to stdout and plots on the
        current matplotlib figure.  Nothing is returned.
    """
    global total_iterations

    start_time = time.time()

    # Per-iteration accuracy history, plotted once training is done.
    accuracies = {'train': [], 'validation': []}

    # The validation feed does not depend on the batch, so build it once
    # instead of on every iteration.
    feed_dict_validation = {x: X_validation,
                            y_true: y_validation,
                            keep_prob_conv: 1,
                            keep_prob_fc: 1}

    for i in range(num_iterations):
        total_iterations += 1

        # Get a batch of training examples: x_batch holds the inputs and
        # y_true_batch the matching true labels.
        x_batch, y_true_batch = next_batch(batch_size, X_train, y_train)

        # Feed used for the optimizer step itself.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch,
                           keep_prob_conv: 0.3,
                           keep_prob_fc: 0.4}

        # Feed used to *measure* accuracy on the same batch.  It uses the
        # same keep probabilities as the validation feed so the two curves
        # are comparable.
        # NOTE(review): assumes keep_prob_conv / keep_prob_fc are dropout
        # keep-probabilities, so 1 disables dropout - confirm against the
        # graph definition.
        feed_dict_train_eval = {x: x_batch,
                                y_true: y_true_batch,
                                keep_prob_conv: 1,
                                keep_prob_fc: 1}

        # Run the optimizer using this batch of training data.
        session.run(optimizer, feed_dict=feed_dict_train)

        acc_train = session.run(accuracy, feed_dict=feed_dict_train_eval)
        acc_validation = session.run(accuracy, feed_dict=feed_dict_validation)
        accuracies['train'].append(acc_train)
        accuracies['validation'].append(acc_validation)

        # Print status every X iterations and on the final iteration.
        # bool(X) guards the modulo against X == 0.
        periodic_report = bool(X) and total_iterations % X == 0
        if periodic_report or i == num_iterations - 1:
            msg = "Iteration: {0:>6}, Training Accuracy: {1:>6.1%}, Validation Accuracy: {2:>6.1%}"
            print(msg.format(total_iterations, acc_train, acc_validation))

    # Report the wall-clock time spent in this call, rounded to seconds.
    time_dif = time.time() - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))

    # The recorded values are accuracies, so label them as such.
    plt.plot(accuracies['train'], label='Training accuracy')
    plt.plot(accuracies['validation'], label='Validation accuracy')
    plt.legend()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment