import matplotlib.pyplot as plt
import numpy as np

def parse_logfile(logfile):
    # the tricky part we have to deal with in these log files is that the
    # job could crash and get restarted, which rewinds and starts re-logging
    # older steps. So we keep all the data in a dictionary, overwrite old
    # data with new, and only compile everything together at the end.

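    # each log line is assumed to carry two whitespace-separated "key:value"
    # tokens, e.g. (hypothetical) "s:1000 trl:3.214500": parts[0] holds the
    # step and parts[1] holds stream:value; the parsing below relies on
    # exactly this shape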
    # read raw data
    streams = {}  # stream:str -> {step: val}
    with open(logfile, "r") as f:
        for line in f:
            parts = line.split()
            step = int(parts[0].split(":")[1])
            stream = parts[1].split(":")[0]
            val = float(parts[1].split(":")[1])
            if stream not in streams:
                streams[stream] = {}
            d = streams[stream]
            d[step] = val

    # now re-represent as a list of (step, val) tuples
    streams_xy = {}
    for k, v in streams.items():
        # get all (step, val) items and sort them by step
        xy = sorted(v.items())
        # unzip the list of tuples into a (steps, vals) pair; tuple()
        # materializes the zip iterator so the result can be reused
        streams_xy[k] = tuple(zip(*xy))

    # return the (xs, ys) pairs, keyed by stream name
    return streams_xy

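# a minimal sketch of what parse_logfile returns (stream names and numbers
# are hypothetical, inferred from how the streams are consumed below):
#   parse_logfile("main.log")
#   -> {"trl": ((0, 10, 20), (10.98, 9.81, 8.52)),   # (steps, train losses)
#       "tel": ((0, 20), (10.97, 8.60))}             # (steps, val losses)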

sz = "124M"
loss_baseline = {  # OpenAI GPT-2 checkpoint val loss on this dataset
    "124M": 3.424958,
    "350M": 3.083089,
    "774M": 3.000580,
    "1558M": 2.831273,
}.get(sz)
hella2_baseline = {  # HellaSwag accuracy for GPT-2
    "124M": 0.294463,
    "350M": 0.375224,
    "774M": 0.431986,
    "1558M": 0.488946,
}.get(sz)
hella3_baseline = {  # HellaSwag accuracy for GPT-3
    "124M": 0.337,
    "350M": 0.436,
    "774M": 0.510,
    "1558M": 0.547,
}.get(sz)
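# note: the GPT-3 numbers above appear to match the zero-shot HellaSwag
# accuracies reported in the GPT-3 paper for the size-equivalent models;
# treat them as reference points (an assumption, not verified here)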
# assumes the logs for each model run are stored in this location
logfile = f"../log{sz}/main.log"
streams = parse_logfile(logfile)
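# optional sanity check (assumes the stream names used below, i.e.
# "trl"/"tel"/"eval"); uncomment to see how many points each stream has:
# print({k: len(v[0]) for k, v in streams.items()})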

plt.figure(figsize=(16, 6))

# Panel 1: losses, both train and val
plt.subplot(121)
xs, ys = streams["trl"]  # training loss
xs = np.array(xs)
ys = np.array(ys)
plt.plot(xs, ys, label=f"llm.c ({sz}) train loss")
print("Min Train Loss:", min(ys))
xs, ys = streams["tel"]  # validation loss
plt.plot(xs, ys, label=f"llm.c ({sz}) val loss")
# horizontal line at the GPT-2 baseline
# (we don't have a GPT-3 loss on this dataset because the GPT-3 weights were never released)
if loss_baseline is not None:
    plt.axhline(y=loss_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint val loss")
plt.xlabel("steps")
plt.ylabel("loss")
plt.yscale('log')
plt.ylim(top=11.0)
plt.legend()
plt.title("Loss")
print("Min Validation Loss:", min(ys))

# Panel 2: HellaSwag eval
plt.subplot(122)
if "eval" in streams:
    xs, ys = streams["eval"]  # HellaSwag eval accuracy
    ys = np.array(ys)
    plt.plot(xs, ys, label=f"llm.c ({sz})")
    # horizontal lines at the GPT-2/3 baselines
    if hella2_baseline:
        plt.axhline(y=hella2_baseline, color='r', linestyle='--', label=f"OpenAI GPT-2 ({sz}) checkpoint")
    if hella3_baseline:
        plt.axhline(y=hella3_baseline, color='g', linestyle='--', label=f"OpenAI GPT-3 ({sz}) checkpoint")
    plt.xlabel("steps")
    plt.ylabel("accuracy")
    plt.legend()
    plt.title("HellaSwag eval")
    print("Max HellaSwag eval:", max(ys))

print("saving plots to dev/plot.png")
plt.savefig('plot.png')
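# usage sketch (assumption: this script is run from the dev/ directory, so
# that ../log{sz}/main.log resolves and plot.png lands in dev/ as printed):
#   cd dev && python <this_script>.py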