Created
March 12, 2019 13:43
-
-
Save khanhnamle1994/1b9fe5a15b62d1e57318f1f95feb295d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def optimize(num_iterations, X):
    """Run `num_iterations` training steps on the global TF session.

    Each step draws a mini-batch, runs one optimizer update, then records
    training and validation accuracy.  Progress is printed every `X`
    iterations (and on the final iteration); afterwards the elapsed time
    is printed and both accuracy curves are plotted.

    Args:
        num_iterations: number of training steps to run in this call.
        X: print-status interval, in iterations (checked against the
           cumulative global `total_iterations`, not the local counter).

    Side effects: advances the global `total_iterations`, mutates the
    graph's variables via `optimizer`, prints to stdout, and draws on the
    current matplotlib figure.
    """
    global total_iterations
    start_time = time.time()

    # Per-iteration accuracy history for plotting.
    # NOTE: despite the name, these are accuracies, not loss values.
    losses = {'train': [], 'validation': []}

    # The validation feed never changes, so build it once outside the loop
    # (dropout keep-probabilities are 1 => dropout disabled for evaluation).
    feed_dict_validation = {x: X_validation,
                            y_true: y_validation,
                            keep_prob_conv: 1,
                            keep_prob_fc: 1}

    for i in range(num_iterations):
        total_iterations += 1

        # Get a batch of training examples: x_batch holds a batch of images
        # and y_true_batch the true labels for those images.
        x_batch, y_true_batch = next_batch(batch_size, X_train, y_train)

        # Map the batch onto the graph's placeholder variables.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch,
                           keep_prob_conv: 0.3,
                           keep_prob_fc: 0.4}

        # One gradient step, then evaluate accuracy on the *updated* weights.
        session.run(optimizer, feed_dict=feed_dict_train)
        acc_train = session.run(accuracy, feed_dict=feed_dict_train)
        acc_validation = session.run(accuracy, feed_dict=feed_dict_validation)
        losses['train'].append(acc_train)
        losses['validation'].append(acc_validation)

        # Print status every X iterations and always on the last one.
        if (total_iterations % X == 0) or (i == (num_iterations - 1)):
            msg = "Iteration: {0:>6}, Training Accuracy: {1:>6.1%}, Validation Accuracy: {2:>6.1%}"
            print(msg.format(total_iterations, acc_train, acc_validation))

    # Report wall-clock time for this call, rounded to whole seconds.
    end_time = time.time()
    time_dif = end_time - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(time_dif)))))

    # Bug fix: the recorded values are accuracies, so label the curves
    # accordingly (the original labels wrongly said "loss").
    plt.plot(losses['train'], label='Training accuracy')
    plt.plot(losses['validation'], label='Validation accuracy')
    plt.legend()
    _ = plt.ylim()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment