Created January 26, 2017 08:10
-
-
Save liruoteng/4a4e9ed9480ca375028b2f4cebcc9768 to your computer and use it in GitHub Desktop.
Caffe: plot training and test loss curves from a log file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import re
import sys

import matplotlib.pyplot as plt
# One non-negative number as it can appear in the log: plain ("0.693147"),
# integral ("2"), or scientific notation ("1.5e-05", "1e-05"). A pattern that
# requires a decimal point would truncate "1.5e-05" to "1.5" and skip "1e-05"
# and "2" entirely, silently corrupting the plot.
_NUMBER = r"(\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)"

# One line whose last "= <number> loss)" is captured, e.g. (presumably)
# "Test net output #0: loss1 = 0.5 (* 0.32 = 0.16 loss)" -> "0.16".
# The greedy ".*" makes the capture the LAST "= " on the line.
_WEIGHTED_LOSS_LINE = r".*= " + _NUMBER + r" loss\)"

# Training display line. Groups: (iteration, loss).
# Accepts both "Iteration N, loss = X" and the variant with a parenthesised
# timing note, "Iteration N (R iter/s, Ts/K iters), loss = X", which some
# Caffe versions print.
TRAIN_LOSS_PATTERN = r"Iteration (\d+)(?: \([^)]*\))?, loss = " + _NUMBER

# Test-net #0 report. Groups: (iteration, loss1, ..., loss6).
# The layout is hard-wired to the network this script was written for:
# header line, two loss lines, five skipped lines, then four loss lines.
# Adjust the line structure if your net reports a different set of outputs.
TEST_LOSS_PATTERN = (
    r"Iteration (\d+), Testing net \(#0\)\n"
    + _WEIGHTED_LOSS_LINE + r"\n"
    + _WEIGHTED_LOSS_LINE + r"\n"
    + r".*\n" * 5
    + _WEIGHTED_LOSS_LINE + r"\n"
    + _WEIGHTED_LOSS_LINE + r"\n"
    + _WEIGHTED_LOSS_LINE + r"\n"
    + _WEIGHTED_LOSS_LINE
)
def main():
    """Plot train and test loss curves from a Caffe log file.

    Usage: ``plot_loss.py <log_file>``.

    Reads the log named by the first command-line argument and extracts
    training losses with TRAIN_LOSS_PATTERN and test losses with
    TEST_LOSS_PATTERN. The test loss at each iteration is the sum of all loss
    values captured by TEST_LOSS_PATTERN. The test iterations and test losses
    are printed, and both curves are saved to ``loss.png`` in the current
    working directory.

    Exits with an error message if no log file argument is given.
    """
    if len(sys.argv) < 2:
        # Raising a bare string is itself a TypeError; exit with the message.
        sys.exit("please provide log file to process")
    log_file_name = sys.argv[1]

    with open(log_file_name, 'r') as log_file:
        log_data = log_file.read()

    training_result = re.findall(TRAIN_LOSS_PATTERN, log_data)
    testing_result = re.findall(TEST_LOSS_PATTERN, log_data)

    # Each training match is (iteration, loss).
    train_iter = [int(train[0]) for train in training_result]
    train_loss = [float(train[1]) for train in training_result]

    # Each test match is (iteration, loss1, ..., lossK); the plotted test
    # loss is the sum of all K captured values. An empty list (no test
    # blocks in the log) simply yields an empty curve.
    test_iter = [int(test[0]) for test in testing_result]
    test_loss = [sum(float(value) for value in test[1:]) for test in testing_result]

    print(test_iter)
    print(test_loss)

    # display
    plt.plot(train_iter, train_loss, 'k', label='Train loss')
    plt.plot(test_iter, test_loss, 'r', label='Test loss')
    plt.legend()
    plt.ylabel('Loss')
    # The x values are the "Iteration N" numbers from the log, not epochs.
    plt.xlabel('Iteration')
    plt.savefig('loss.png')
if __name__ == '__main__':
    # Run the plotting entry point only when executed as a script,
    # not when this module is imported.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment