Skip to content

Instantly share code, notes, and snippets.

@rizar
Created September 16, 2018 17:42
Show Gist options
  • Save rizar/42834d8f2540c7c26f0bcbd551c4a6e8 to your computer and use it in GitHub Desktop.
Save rizar/42834d8f2540c7c26f0bcbd551c4a6e8 to your computer and use it in GitHub Desktop.
The code I used to plot alphas.
def load_alpha_data(log_file):
    """Parse alpha rows out of a log file.

    Only lines that begin with ``'data'`` are used. On such a line the
    character at position 5 is read as a single-digit group index (0, 1
    or 2), and everything from position 7 onward is split on whitespace
    and converted to floats.

    Returns a list of three lists, one per group index. Each inner list
    holds that group's float rows in the order they appear in the file.
    """
    buckets = [[] for _ in range(3)]
    with open(log_file) as handle:
        for text in handle:
            if not text.startswith('data'):
                continue
            index = int(text[5])
            values = list(map(float, text[7:].split()))
            buckets[index].append(values)
    return buckets
def _row_softmax(logits):
    """Return the row-wise softmax of the 2-D array ``logits``.

    The row maximum is subtracted before exponentiating so that large
    logits cannot overflow to ``inf``, which would make the ratio ``nan``.
    The result is mathematically identical to
    ``exp(x) / exp(x).sum(axis=1)[:, None]``.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)
    weights = numpy.exp(shifted)
    return weights / weights.sum(axis=1, keepdims=True)


def plot_all_alphas(data, only_3=False):
    """Plot softmax-normalised alpha rows, one figure per group.

    Parameters
    ----------
    data : sequence of 2-D array-likes
        One entry per group, such as the three-element list returned by
        ``load_alpha_data``. Each entry is a list of equal-length rows of
        logits.
    only_3 : bool
        If true, keep only columns 4, 5 and 7 of every row before
        normalising, and add a legend labelling them 'X', 'R', 'Y'.

    Each group's rows are softmax-normalised across columns. Every
    resulting column is drawn as a line over the row index, and the
    x-axis is labelled ``alpha_<group index>``. Groups with no rows are
    skipped instead of raising.
    """
    for k, rows in enumerate(data):
        arr = numpy.array(rows, dtype=float)
        if arr.size == 0:
            # Nothing was recorded for this group. The 2-D column indexing
            # and axis=1 reductions below would fail on an empty array.
            continue
        if only_3:
            arr = arr[:, [4, 5, 7]]
        probs = _row_softmax(arr)
        pyplot.figure(figsize=(15, 3))
        for column in probs.T:
            pyplot.plot(column)
        if only_3:
            pyplot.legend(['X', 'R', 'Y'])
        pyplot.xlabel('alpha_{}'.format(k))
        pyplot.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment