Skip to content

Instantly share code, notes, and snippets.

@aflaag
Last active September 28, 2022 16:57
Show Gist options
  • Save aflaag/2420287544285091da8ffe9a26041ca7 to your computer and use it in GitHub Desktop.
Save aflaag/2420287544285091da8ffe9a26041ca7 to your computer and use it in GitHub Desktop.
from functools import partial
import jax.numpy as jnp
import jax
# Perturbation scales tried side by side on every search step. Stored with
# shape (1, k, 1) so they broadcast against noise of shape (n, k, 2).
STEP_SIZES = jnp.array([1., 0.5, 0.1, 0.05, 0.01, 0.005, 0.001])[None, :, None]
k = STEP_SIZES.shape[1]

# Root PRNG key for the script, immediately split into a running key and a
# separate key named for initialisation.
key, init_key = jax.random.split(jax.random.PRNGKey(1337))

# Data points (x, y) that the line w * x + b is fitted to.
xs = jnp.array([-2, 3, 6, -10, -5])
ys = jnp.array([5, 7, 8.2, 1.8, 3.8])
def generate_adjustments(key, n):
    """Sample random offsets to add to the line parameters.

    Draws standard-normal noise of shape ``(n, k, 2)`` and multiplies it by
    the module-level ``STEP_SIZES``. That array has shape ``(1, k, 1)``, so
    each of the ``k`` slices along the middle axis is scaled by its own step
    size.

    Args:
        key: JAX PRNG key used for the draw.
        n: Number of samples drawn per step size.

    Returns:
        Array of shape ``(n, k, 2)`` containing the scaled offsets.
    """
    noise = jax.random.normal(key, shape=(n, k, 2))
    return noise * STEP_SIZES
def loss(line):
    """Return the mean squared error of a line on the module-level data.

    Args:
        line: Two-element sequence ``(w, b)`` unpacked as the multiplier of
            ``xs`` and the additive offset.

    Returns:
        Mean over all data points of ``(w * xs + b - ys) ** 2``.
    """
    slope, intercept = line
    residuals = slope * xs + intercept - ys
    return jnp.square(residuals).mean()
@partial(jax.jit, static_argnames=['n'])
def step(seed, line, n):
    """Run one random-search step and return the best perturbed line.

    Builds ``n * k`` candidates by adding the offsets from
    ``generate_adjustments`` to ``line``, evaluates ``loss`` on every
    candidate and returns the one with the smallest loss. The unperturbed
    ``line`` is not among the candidates.

    Args:
        seed: JAX PRNG key forwarded to ``generate_adjustments``.
        line: Current ``(w, b)`` parameters; broadcast against the offsets.
        n: Number of candidates per step size. Declared static for ``jit``,
            so each distinct value triggers its own compilation.

    Returns:
        The candidate ``(w, b)`` with the lowest loss.
    """
    candidates = line + generate_adjustments(seed, n)
    # Two nested vmaps map ``loss`` over the two leading axes, leaving the
    # trailing (w, b) pair as the argument of each call.
    batched_loss = jax.vmap(jax.vmap(loss))
    losses = batched_loss(candidates)
    # argmin works on the flattened losses; convert back to a 2-D index.
    best = jnp.unravel_index(jnp.argmin(losses), losses.shape)
    return candidates[best]
# Draw the starting (w, b) from the dedicated ``init_key``. Sampling from
# ``key`` here and then splitting that same ``key`` in the loop would use one
# PRNG key twice, which JAX's PRNG design asks callers to avoid.
line = jax.random.uniform(init_key, (2,))
print(line)
for _ in range(3):
    # Split off a fresh subkey so every step samples new perturbations.
    key, seed = jax.random.split(key)
    line = step(seed, line, 1000)
    # NOTE(review): the source lost its indentation; printing after every
    # step is assumed (the final line is the last value printed either way).
    print(line)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment