Skip to content

Instantly share code, notes, and snippets.

@willkurt
Created January 18, 2021 19:38
Show Gist options
  • Save willkurt/53db2e834ce84753aadb98d7485314ac to your computer and use it in GitHub Desktop.
Save willkurt/53db2e834ce84753aadb98d7485314ac to your computer and use it in GitHub Desktop.
# NOTE: This is gross, but it contains the plotting code to build this:
# https://twitter.com/willkurt/status/1351242237963874311
# ImageMagick does all the gif work: convert *.png -delay 60 -duplicate 20,22 logistic.gif
# This optimization loop is not ideal for any purpose other than snapshotting the learning
# Train the MNIST logistic-regression weights `w` by gradient descent in
# progressively longer bursts, saving one PNG frame of the per-class weight
# images (plus test accuracy and training NLL) after each burst.
# Relies on names defined earlier in the notebook/script: nll, d_nll_wrt_w_c,
# softmax, vmap, jnp, plt, w, X_train, y_train, X_test, y_test.
import os

FRAME_DIR = "./logistic_gif"


def _accuracy(y, X, weights):
    """Return the fraction of rows whose predicted class matches the label.

    The predicted class is the per-row argmax of ``softmax(X @ weights)``; the
    true class is the per-row argmax of ``y`` (presumably one-hot labels --
    confirm against the data-preparation code).
    """
    actual = vmap(jnp.argmax)(y)
    predicted = vmap(jnp.argmax)(softmax(jnp.dot(X, weights)))
    return (actual == predicted).sum() / y.shape[0]


# Figure.savefig does not create missing directories, so make sure the frame
# directory exists before the first frame is written.
os.makedirs(FRAME_DIR, exist_ok=True)

lr = 0.000005
init_nll = nll(y_train, X_train, w)
# Sentinel: with next_nll == 0 the first loop check is "initial NLL > 10".
next_nll = 0
img_ctr = 0
# Gradient steps per frame; grows each pass (iters -> int(iters**1.07)) so
# later frames cover more optimization steps.
iters = 10
# Keep going while a burst lowers the training NLL by more than 10. The loop
# also stops if the NLL rises (difference < 10) or becomes NaN (comparison
# with NaN is False).
while (init_nll - next_nll) > 10:
    init_nll = nll(y_train, X_train, w)
    for _ in range(iters):
        w -= lr * d_nll_wrt_w_c(y_train, X_train, w)
    iters = int(iters**1.07)
    next_nll = nll(y_train, X_train, w)
    print(next_nll)
    # Training accuracy is computed but not shown in the frame title.
    current_acc = _accuracy(y_train, X_train, w)
    current_test = _accuracy(y_test, X_test, w)
    # w has not changed since next_nll was computed, so reuse that value
    # instead of evaluating the NLL over the whole training set again.
    current_nll = next_nll

    fig, ax = plt.subplots(2, 5, facecolor='w')
    for i in range(10):
        # Column i of w is shown as a 28x28 image, one panel per class.
        img = 1 - w[:, i].reshape((28, 28))
        panel = ax[i // 5][i % 5]
        panel.imshow(img, cmap="Greys")
        panel.set_title(i)
    fig.subplots_adjust(wspace=0.7)
    fig.suptitle(
        "Using Logistic Regression on MNIST\nAccuracy (Test) {:0.1%} NLL: {:10.2f}".format(
            current_test,
            current_nll),
        fontsize=16, y=1.02)
    fig.set_size_inches(10, 6)
    img_ctr += 1
    fig.savefig(os.path.join(FRAME_DIR, "{:03d}.png".format(img_ctr)),
                facecolor='w',
                bbox_inches='tight')
    # Release the figure once it is on disk; otherwise every frame stays open,
    # memory grows with each frame, and matplotlib warns past 20 open figures.
    plt.close(fig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment