Last active
June 28, 2017 05:21
-
-
Save ami-GS/50bb3995b337c1ec5b4222b3ee8f5dd1 to your computer and use it in GitHub Desktop.
caffeのプログレスバー
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
progress_bar.py
Wrap a caffe training command and show its progress as a text progress bar.
Copyright (c) 2017 ami_GS
This software is released under the MIT License.
http://opensource.org/licenses/mit-license.php
""" | |
import os
import re
import sys
import time
from collections import deque
from subprocess import PIPE, Popen
def parse_solver(cmd):
    """Return ``(max_iter, display)`` read from the solver prototxt named in *cmd*.

    The solver path comes from the first ``--solver=PATH``, ``-solver=PATH``,
    ``--solver PATH`` or ``-solver PATH`` argument found in *cmd*.

    Within the file, only ``key: value`` lines whose key is exactly
    ``max_iter`` or ``display`` are used; anything after ``#`` is treated as a
    comment.  A value that the file never sets is returned as ``-1``; if a key
    appears more than once, the last occurrence wins.

    Raises:
        ValueError: if *cmd* names no solver file, or a matched value is not
            an integer.
        IOError/OSError: if the solver file cannot be opened.
    """
    solver_prototxt = None
    for i, option in enumerate(cmd):
        for flag in ('--solver', '-solver'):
            if option.startswith(flag + '='):
                solver_prototxt = option[len(flag) + 1:]
            elif option == flag and i + 1 < len(cmd):
                # Space-separated form: the path is the next argument.
                solver_prototxt = cmd[i + 1]
        if solver_prototxt is not None:
            break
    if not solver_prototxt:
        raise ValueError("no solver prototxt given; expected --solver=<path> in the command")

    values = {'max_iter': -1, 'display': -1}  # -1 means "not set in the file"
    with open(solver_prototxt, 'r') as f:
        for line in f:
            # Drop trailing comments so "display: 20  # note" still parses and
            # commented-out settings are ignored.
            content = line.split('#', 1)[0]
            key, sep, value = content.partition(':')
            key = key.strip()
            # Exact key match: substring tests would also hit e.g.
            # snapshot_prefix: "display_max_iter".
            if sep and key in values:
                values[key] = int(value.strip())
    return values['max_iter'], values['display']
class ProgressBar:
    """Text progress bar for a fixed number of iterations, drawn on stdout.

    Attributes:
        bar_len: width of the bar in characters (class-wide).
        template: format for the status text after the bar
            (iteration range, loss, elapsed seconds).
        allStartTime: ``time.time()`` value the final summary measures from;
            0 until the caller sets it.
        progress: filled blocks so far, kept as an unrounded float.
        one_block_progress: number of iterations represented by one block.
        allNum: total number of iterations.
    """
    bar_len = 40

    def __init__(self, allNum):
        """Create a bar for *allNum* total iterations.

        Raises:
            ValueError: if *allNum* is not positive (e.g. ``-1`` for "unknown").
        """
        if allNum <= 0:
            raise ValueError("ProgressBar needs a positive total, got %r" % (allNum,))
        self.template = "Iteration:%d-%d Loss:%f Time:%f"
        self.allStartTime = 0
        self.progress = 0.0
        # float() so Python 2 does not floor-divide to 0 when allNum < bar_len.
        self.one_block_progress = float(allNum) / self.bar_len
        self.allNum = allNum

    def update(self, prvIterN, iterN, loss, delta):
        """Advance by ``iterN - prvIterN`` iterations and redraw the bar.

        *loss* and *delta* (seconds for this step) are shown after the bar.
        While unfinished, the line ends with ``"\\r"`` so the next update
        overwrites it; once the bar is full a newline and a total-time summary
        are printed instead.
        """
        self.progress += float(iterN - prvIterN) / self.one_block_progress
        # Clamp so overshoot (rounding, extra iterations) cannot overflow the
        # bar or skip the completion branch.
        filled = min(max(int(round(self.progress)), 0), self.bar_len)
        sys.stdout.write("[" + "#" * filled + "." * (self.bar_len - filled) + "] "
                         + self.template % (prvIterN, iterN, loss, delta))
        if filled < self.bar_len:
            sys.stdout.write("\r")
        else:
            sys.stdout.write("\n%d iterations done in %f seconds\n"
                             % (self.allNum, time.time() - self.allStartTime))
        # A line ending in "\r" is not flushed by line buffering; flush so the
        # bar actually redraws on a terminal.
        sys.stdout.flush()
def run_trace(cmd):
    """Run the caffe command *cmd* and show its training progress as a bar.

    *cmd* is the argument list to execute, e.g.
    ``['./build/tools/caffe', 'train', '--solver=solver.prototxt']``.  It is
    passed to ``Popen`` as a list (no shell), so arguments are not re-split
    or shell-interpreted.

    The solver prototxt is parsed before the command starts.  The command's
    stderr is read line by line.  Lines matching ``] Iteration N ... loss = X``
    and ``] Iteration N ... lr = Y`` that do not contain ``Testing`` drive the
    bar; all other stderr output is hidden.  If the command exits non-zero,
    its last stderr lines are echoed to our stderr so the failure is visible.

    Returns:
        The command's exit status, as returned by ``Popen.wait()``.

    Raises:
        ValueError: if the solver prototxt does not set a positive
            ``max_iter`` (or propagated from ``parse_solver``).
    """
    # Parse before launching so a bad solver file does not leave the command
    # running with nobody draining its stderr pipe.
    max_iter, _display = parse_solver(cmd)
    if max_iter <= 0:
        raise ValueError("max_iter is missing or not positive in the solver prototxt")
    p_bar = ProgressBar(max_iter)

    # Allows extra text between the iteration number and the value, e.g.
    # "Iteration 100 (3.2 iter/s, 31s/100 iters), loss = 0.5".
    log_pattern = re.compile(r"\]\s*Iteration (\d+)\b.*?\b(loss|lr) = (\S+)")
    tail = deque(maxlen=20)  # recent stderr lines, replayed on failure

    start_time = 0
    iter_num = -1
    # universal_newlines=True makes readline() return str (not bytes) on Python 3.
    proc = Popen(cmd, stderr=PIPE, universal_newlines=True)
    # iter(readline, '') ends exactly at EOF instead of spinning on poll().
    for line in iter(proc.stderr.readline, ''):
        tail.append(line)
        match = None if "Testing" in line else log_pattern.search(line)
        if match is None:
            continue
        prev_iter_num = iter_num
        iter_num = int(match.group(1))
        kind = match.group(2)
        if kind == "lr" and prev_iter_num == iter_num:
            # Same iteration as the preceding line: restart the interval timer.
            start_time = time.time()
            if p_bar.allStartTime == 0:
                p_bar.allStartTime = start_time
        elif kind == "loss" and prev_iter_num not in (-1, iter_num):
            try:
                loss = float(match.group(3))
            except ValueError:
                continue  # unparseable value; skip this update rather than crash
            p_bar.update(prev_iter_num, iter_num, loss, time.time() - start_time)
    proc.stderr.close()
    returncode = proc.wait()

    if 0 < p_bar.progress and int(round(p_bar.progress)) < p_bar.bar_len:
        # An unfinished bar line ends with "\r"; move to a fresh line.
        sys.stdout.write("\n")
        sys.stdout.flush()
    if returncode != 0:
        sys.stderr.write("command exited with status %d; last stderr lines:\n" % returncode)
        sys.stderr.writelines(tail)
    return returncode
if __name__ == "__main__":
    # Everything after the script name is the caffe command to run.
    args = sys.argv[1:]
    if not args:
        sys.stderr.write("usage: python progress_bar.py <caffe binary> train "
                         "--solver=<solver.prototxt> [other caffe flags...]\n")
        sys.exit(2)
    # Propagate the wrapped command's exit status (None -> 0).
    sys.exit(run_trace(args))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
python progress_bar.py ./build/tools/caffe train --solver=OOOO_solver.prototxt
で使えるはず (This should work; replace OOOO_solver.prototxt with the path to your own solver file.)