Created
January 18, 2021 19:38
-
-
Save willkurt/53db2e834ce84753aadb98d7485314ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: This is gross, but it contains the plotting code to build this:
# https://twitter.com/willkurt/status/1351242237963874311
# ImageMagick does all the gif work: convert *.png -delay 60 -duplicate 20,22 logistic.gif
# This optimization loop is not ideal for any purpose other than snapshotting the learning
# Gradient-descent loop that writes one PNG frame of the weight matrix per
# outer pass. Relies on names defined elsewhere in the file: nll,
# d_nll_wrt_w_c, softmax, w, X_train, y_train, X_test, y_test, plt, jnp, vmap.
import os  # only needed here, to make sure the frame directory exists

lr = 0.000005
init_nll = nll(y_train, X_train, w)
next_nll = 0  # sentinel: the first loop test compares the starting NLL against 0
img_ctr = 0
iters = 10

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./logistic_gif", exist_ok=True)

# Keep training until one outer pass improves the NLL by 10 or less.
while (init_nll - next_nll) > 10:
    init_nll = nll(y_train, X_train, w)
    for _ in range(iters):
        w -= lr * d_nll_wrt_w_c(y_train, X_train, w)
    # Take more gradient steps before each successive frame.
    iters = int(iters**1.07)
    next_nll = nll(y_train, X_train, w)
    print(next_nll)

    # Fraction of rows whose arg-max prediction matches the arg-max label.
    # current_acc is not shown in the figure; it is kept in case later code
    # reads it.
    current_acc = (
        vmap(jnp.argmax)(y_train)
        == vmap(jnp.argmax)(softmax(jnp.dot(X_train, w)))
    ).sum() / y_train.shape[0]
    current_test = (
        vmap(jnp.argmax)(y_test)
        == vmap(jnp.argmax)(softmax(jnp.dot(X_test, w)))
    ).sum() / y_test.shape[0]
    # w has not changed since next_nll was computed, so reuse that value
    # instead of another full pass over the training data
    # (assumes nll is deterministic).
    current_nll = next_nll

    fig, ax = plt.subplots(2, 5, facecolor='w')
    for i in range(10):
        # Column i of w drawn as a 28x28 image, values flipped via 1 - w.
        img = 1 - w[:, i].reshape((28, 28))
        panel = ax[i // 5][i % 5]
        panel.imshow(img, cmap="Greys")
        panel.set_title(i)
    fig.subplots_adjust(wspace=0.7)
    fig.suptitle(
        "Using Logistic Regression on MNIST\nAccuracy (Test) {:0.1%} NLL: {:10.2f}".format(
            current_test,
            current_nll),
        fontsize=16, y=1.02)
    fig.set_size_inches(10, 6)
    img_ctr += 1
    fig.savefig("./logistic_gif/{:03d}.png".format(img_ctr),
                facecolor='w',
                bbox_inches='tight')
    # Release the figure; otherwise pyplot keeps every frame's figure alive
    # and memory grows with each pass.
    plt.close(fig)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment