# MNIST digit classifier in JAX, trained with momentum SGD.
# Gist by @aflaag (864734985670dc149fabc4408ccbf3e6), last active September 29, 2022.
from PIL import Image
import jax.numpy as jnp
import numpy as np
import jax
from IPython import display
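
# The original gist assumes x_train / y_train (and x_test / y_test further
# down) are already in scope. A minimal loading sketch, assuming tensorflow
# is installed, using keras' bundled MNIST: images become (784, 1) column
# vectors scaled to [0, 1], labels become one-hot (10, 1) column vectors,
# matching the shapes feed_forward and cross_entropy below expect.
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 784, 1).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784, 1).astype("float32") / 255.0
y_train = np.eye(10)[y_train].reshape(-1, 10, 1).astype("float32")
y_test = np.eye(10)[y_test].reshape(-1, 10, 1).astype("float32")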

STEP_SIZE = 1e-3  # learning rate (also reused below as the weight-init scale)
BETA = 0.9        # momentum decay factor
BATCH_SIZE = 32

data = x_train
labels = y_train

def save_params(params):
    # Checkpoint every parameter array as classif/<idx>.npy.
    for idx, param in enumerate(params):
        jnp.save(f"classif/{idx}", param)

def generate_image(image_array):
    # Turn a flat 784-element image back into a 28x28 PIL image.
    reshaped = np.array(image_array).reshape((28, 28))
    return Image.fromarray(reshaped)

def load_params(folder):
    # Six arrays = 3 layer transitions x (weights, bias); see the shape
    # note after generate_params below.
    return [jnp.load(f"{folder}/{idx}.npy") for idx in range(6)]

def generate_params(key, layers_info):
    # One (weights, bias) pair per layer transition, initialised with small
    # gaussian noise (the gist reuses STEP_SIZE = 1e-3 as the init scale).
    params = []
    for i in range(len(layers_info) - 1):
        width = layers_info[i]
        height = layers_info[i + 1]
        key, w_seed = jax.random.split(key)
        w = jax.random.normal(w_seed, (height, width)) * STEP_SIZE
        key, b_seed = jax.random.split(key)
        b = jax.random.normal(b_seed, (height, 1)) * STEP_SIZE
        params.append(w)
        params.append(b)
    return params
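
# For layers_info = [784, 32, 32, 10] this yields six arrays in
# [w1, b1, w2, b2, w3, b3] order, with shapes (32, 784), (32, 1), (32, 32),
# (32, 1), (10, 32), (10, 1), which is why load_params reads exactly six files.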

def generate_momentum(layers_info):
    # Zero-initialised velocity buffers, shaped like the parameters.
    momentum = []
    for i in range(len(layers_info) - 1):
        width = layers_info[i]
        height = layers_info[i + 1]
        w = jnp.zeros((height, width))
        b = jnp.zeros((height, 1))
        momentum.append(w)
        momentum.append(b)
    return momentum

def feed_forward(params, x):
    # x is a (784, 1) column vector; hidden layers use ReLU, and the output
    # layer applies a softmax over the 10 class rows.
    inp = x
    for i in range(0, len(params) - 2, 2):
        w, b = params[i : i + 2]
        inp = jax.nn.relu(w @ inp + b)
    w, b = params[-2:]
    return jax.nn.softmax(w @ inp + b, axis=0)
    # return w @ inp + b  # alternative: return raw logits

def cross_entropy(p, q):
    # p: one-hot target, q: predicted distribution; 1e-10 guards against log(0).
    return -jnp.sum(p * jnp.log(q + 1e-10))

def loss(params, x, y):
    out = feed_forward(params, x)
    return cross_entropy(y, out)

@jax.jit
def step(params, momentum, xs, ys):
    # Mean cross-entropy over the batch, vmapped over the leading axis.
    batch_loss = lambda params, batch_x, batch_y: jax.vmap(loss, in_axes=(None, 0, 0))(params, batch_x, batch_y).mean()
    loss_value, gradient = jax.value_and_grad(batch_loss)(params, xs, ys)
    # Momentum SGD: v <- BETA * v + g, then p <- p - STEP_SIZE * v.
    new_momentum = [m * BETA + g for m, g in zip(momentum, gradient)]
    new_params = [p - m * STEP_SIZE for p, m in zip(params, new_momentum)]
    return new_params, new_momentum, loss_value
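
# A quick shape sanity check (not part of the original gist): one jitted step
# on random inputs. With the tiny 1e-3 init the softmax output is near
# uniform, so the initial loss should sit close to -log(1/10) ≈ 2.30.
_key = jax.random.PRNGKey(42)
_params = generate_params(_key, [784, 32, 32, 10])
_momentum = generate_momentum([784, 32, 32, 10])
_xs = jax.random.normal(_key, (BATCH_SIZE, 784, 1))
_ys = jax.nn.one_hot(jnp.zeros(BATCH_SIZE, dtype=jnp.int32), 10).reshape(BATCH_SIZE, 10, 1)
_, _, _loss = step(_params, _momentum, _xs, _ys)
print("sanity-check loss:", _loss)  # expect roughly 2.30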

key = jax.random.PRNGKey(0)
layers_info = [784, 32, 32, 10]  # 28x28 inputs, two hidden layers, 10 classes
params = generate_params(key, layers_info)
momentum = generate_momentum(layers_info)

c = 0
for e in range(100):
    key, seed = jax.random.split(key)
    # Reusing the same seed yields the same permutation for both arrays,
    # keeping images and labels aligned after the shuffle.
    data = jax.random.permutation(seed, data)
    labels = jax.random.permutation(seed, labels)
    for i in range(0, 60_000, BATCH_SIZE):
        params, momentum, loss_value = step(params, momentum, data[i : i + BATCH_SIZE], labels[i : i + BATCH_SIZE])
        # Log the loss roughly every 1000 batches.
        if c == 1000:
            print(e, loss_value)
            c = 0
        else:
            c += 1
    save_params(params)  # checkpoint after each epoch
save_params(params)

# params = load_params("classif")
giusti = 0  # Italian for "correct": number of correctly classified test images
for i, (x, y) in enumerate(zip(x_test, y_test)):
    key, seed = jax.random.split(key)
    # random_image = jax.random.choice(seed, data)
    # random_image_label = jax.random.choice(seed, labels)
    out = feed_forward(params, x)
    loss_v = cross_entropy(y, out)
    out_value = jnp.argmax(out)
    y_value = jnp.argmax(y)
    if out_value == y_value:
        giusti += 1
    else:
        # Save each misclassified digit as a PNG for inspection.
        generate_image(x * 255).convert('RGB').save(f"{i}.png")
        print(i, out_value, y_value)
    # print(loss_v, out_value, y_value)
print(giusti / 10000 * 100)  # accuracy in percent over the 10,000 test images
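
# A vectorised alternative (a sketch, not part of the original gist): the loop
# above evaluates one image at a time; vmapping feed_forward over the whole
# test set computes the same accuracy in a single call.
batched_forward = jax.vmap(feed_forward, in_axes=(None, 0))
preds = jnp.argmax(batched_forward(params, jnp.asarray(x_test)), axis=1).squeeze()
truth = jnp.argmax(jnp.asarray(y_test), axis=1).squeeze()
print((preds == truth).mean() * 100)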