Created
October 8, 2020 14:53
-
-
Save FrankGrimm/bd3d8e617ed2ef1cd709df85f350761c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Summarize JSON-lines training logs (*.stderr) found in a directory.
#
# Usage: python <this script> [basedir]
#
# Each log line that parses as a JSON object with a "type" key is treated as
# a record; everything else in the log is ignored. For every log file the
# script prints the run metadata, a preview of the training losses and the
# per-epoch evaluation accuracy, followed by a tab-separated table with one
# row per "<base_model>/<attention>" combination.

import sys
import json
import os
import os.path
from glob import glob

# Directory scanned when no path is given on the command line.
DEFAULT_BASEDIR = "~/probds/lsh/logs/"
# Width of the run-name column in the final table.
PADLEN = 40
# Number of losses shown at each end of the abbreviated loss list.
LOSS_PREVIEW = 5


def parse_log(logfile):
    """Extract training records from one JSON-lines log file.

    Lines that are not valid JSON, or that decode to something other than
    an object with a ``"type"`` key, are skipped silently.

    Args:
        logfile: Path of the log file to read.

    Returns:
        A tuple ``(train_losses, epoch_evals, run_args)``:
        the ``"loss"`` values of all ``train_loss`` records in file order,
        a dict mapping ``"epoch"`` to ``"acc"`` for all ``eval`` records,
        and the last ``args`` record seen (``{}`` if there was none).

    Raises:
        KeyError: If a recognized record lacks one of its expected keys.
    """
    train_losses = []
    epoch_evals = {}
    run_args = {}
    with open(logfile, "rt") as infile:
        for raw_line in infile:
            try:
                record = json.loads(raw_line)
            except ValueError:
                # json.JSONDecodeError subclasses ValueError; the log may
                # contain arbitrary non-JSON output between records.
                continue
            # Valid JSON is not necessarily an object: a bare number, string
            # or null would break the key lookups below.
            if not isinstance(record, dict) or "type" not in record:
                continue
            record_type = record["type"]
            if record_type == "train_loss":
                train_losses.append(record["loss"])
            elif record_type == "eval":
                epoch_evals[record["epoch"]] = record["acc"]
            elif record_type == "args_input":
                # Recognized but intentionally not used in the report.
                continue
            elif record_type == "args":
                run_args = record
            else:
                print("unhandled", record)
    return train_losses, epoch_evals, run_args


def abbreviate(values, edge=LOSS_PREVIEW):
    """Return *values* shortened to its first and last *edge* items.

    The elided middle is replaced by the string ``"..."``. Sequences with at
    most ``2 * edge`` items are returned in full, so no item is ever shown
    twice and no ellipsis appears when nothing was left out.

    Args:
        values: Sequence to abbreviate.
        edge: Number of items to keep at each end.

    Returns:
        A new list; the input is not modified.
    """
    if len(values) <= 2 * edge:
        return list(values)
    # Index from the front rather than using values[-edge:], which would
    # return the whole sequence for edge == 0.
    tail_start = len(values) - edge
    return list(values[:edge]) + ["..."] + list(values[tail_start:])


def main(argv=None):
    """Print a report for every ``*.stderr`` log in the chosen directory.

    Args:
        argv: Argument vector; defaults to ``sys.argv``. ``argv[1]``, when
            present, overrides the default log directory.

    Returns:
        A dict mapping ``"<base_model>/<attention>"`` to that run's
        ``{epoch: acc}`` dict. If several logs share a key, the one that
        sorts last by path wins.
    """
    argv = sys.argv if argv is None else argv
    basedir = argv[1] if len(argv) > 1 else DEFAULT_BASEDIR
    print("[basedir] %s" % basedir)
    pattern = os.path.expanduser(os.path.join(basedir, "*.stderr"))

    all_results = {}
    max_epoch = -1
    # glob() returns paths in arbitrary order; sort for a reproducible report.
    for logfile in sorted(glob(pattern)):
        print(logfile)
        train_losses, epoch_evals, run_args = parse_log(logfile)

        print(f"[run] {logfile}")
        num_epochs = run_args.get("num_epochs", None)
        attention = run_args.get("attention", None)
        print(f"[meta] epochs: {num_epochs} attention: {attention}")
        print("[train loss] %s" % ", ".join(map(str, abbreviate(train_losses))))

        # A missing or falsy "attention" value is reported as "noattn".
        attn_mechanism = run_args.get("attention", "noattn") or "noattn"
        base_model = run_args.get("base_model", "unknown")
        all_results["%s/%s" % (base_model, attn_mechanism)] = epoch_evals

        # NOTE(review): epochs are assumed to be integers (they feed range()
        # below) - confirm against the code that writes the logs.
        for epoch, acc in epoch_evals.items():
            print(f"[epoch] {epoch}: {acc}")
            if epoch > max_epoch:
                max_epoch = epoch
        print()

    print()
    print()

    # Table columns are epochs 1..max_epoch; the range is empty when no eval
    # record was found.
    epochs = range(1, max_epoch + 1)
    print("".ljust(PADLEN + 5, " "), "\t\t", "\t".join(map(str, epochs)))
    for run_key, eval_results in all_results.items():
        # Runs without any "args" record carry no identifying information.
        if run_key == "unknown/noattn":
            continue
        per_epoch = (str(eval_results.get(epoch, "")) for epoch in epochs)
        print(run_key[:PADLEN].ljust(PADLEN, " "), "\t\t", "\t".join(per_epoch))
    return all_results


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment