Skip to content

Instantly share code, notes, and snippets.

@bayerj
Created February 12, 2020 14:03
Show Gist options
  • Save bayerj/6064edd404e65189105dbd9f7945b3d3 to your computer and use it in GitHub Desktop.
Save bayerj/6064edd404e65189105dbd9f7945b3d3 to your computer and use it in GitHub Desktop.
JAX learning step for models with non-differentiable parameters, such as integers.
import jax
import jax.experimental.optimizers
from jax.api import _check_inexact_input_vjp
from jax import tree_util as tu
import numpy as onp
def make_resilient_step(loss, sample_params, split, join, optimizer):
    """Build an optimizer step that only differentiates the learnable leaves.

    ``split`` separates ``sample_params`` into a learnable part, which is handed
    to the optimizer, and a non-learnable part (e.g. integer leaves, which
    cannot be differentiated). The non-learnable part is captured once, here,
    and re-attached with ``join`` on every loss evaluation. It therefore stays
    fixed at its value in ``sample_params``.

    Args:
        loss: Callable ``loss(params, *args, **kwargs)`` returning a scalar,
            where ``params`` has the structure of ``sample_params``.
        sample_params: Initial parameters.
        split: ``split(params) -> (learnable, non_learnable)``.
        join: ``join(learnable, non_learnable) -> params``; inverse of ``split``.
        optimizer: ``(opt_init, opt_update, get_params)`` triple; ``opt_update``
            is called as ``opt_update(i, grads, opt_state)``.

    Returns:
        ``(opt_state, step)``: the initial optimizer state for the learnable
        part, and ``step(opt_state, *args, **kwargs) -> (loss_value, new_state)``.

    ``step`` passes 0, 1, 2, ... (the number of its previous successful calls)
    as the iteration index ``i`` to ``opt_update``. The counter is ordinary
    Python state, so call ``step`` from Python rather than from inside
    ``jax.jit``/``lax.scan``. Each call to ``make_resilient_step`` starts a
    fresh count.
    """
    sample_learn_params, non_learn_params = split(sample_params)

    def guarded_loss(learn_params, *args, **kwargs):
        # Re-attach the fixed, non-learnable leaves so `loss` sees full params
        # while only `learn_params` is differentiated.
        params = join(learn_params, non_learn_params)
        return loss(params, *args, **kwargs)

    # One forward pass yields both the loss value and its gradient.
    loss_and_grad = jax.value_and_grad(guarded_loss)
    opt_init, opt_update, get_params = optimizer
    opt_state = opt_init(sample_learn_params)

    @jax.jit
    def _jitted_step(i, opt_state, *args, **kwargs):
        # `i` is a traced argument, so changing it does not force a recompile.
        params = get_params(opt_state)
        loss_value, grads = loss_and_grad(params, *args, **kwargs)
        return loss_value, opt_update(i, grads, opt_state)

    # A constant index would freeze step-size schedules and Adam-style bias
    # correction, so the iteration index is tracked here.
    calls = 0

    def step(opt_state, *args, **kwargs):
        nonlocal calls
        result = _jitted_step(calls, opt_state, *args, **kwargs)
        # Only advance the count once the update actually succeeded.
        calls += 1
        return result

    return opt_state, step
def func(params):
    """Toy objective: the float entry "c" squared, scaled by the entry "d"."""
    c_squared = params["c"] ** 2
    return c_squared * params["d"]
def split(params):
    """Partition the leaves of ``params`` by whether their dtype is inexact.

    Returns ``(learnable, rest)``:
      * ``learnable`` -- list of leaves with a float or complex dtype
        (``numpy.inexact``), in flatten order;
      * ``rest`` -- ``(tree_def, [(flat_index, leaf), ...])`` for every other
        leaf, with ``flat_index`` ascending. ``join`` uses this pair to
        rebuild the original tree.
    """
    leaves, tree_def = tu.tree_flatten(params)
    learnable = []
    frozen = []
    # A single pass: each leaf's dtype is inspected once and routed to one list.
    for position, leaf in enumerate(leaves):
        leaf_dtype = jax.core.get_aval(leaf).dtype
        if jax.dtypes.issubdtype(leaf_dtype, onp.inexact):
            learnable.append(leaf)
        else:
            frozen.append((position, leaf))
    return learnable, (tree_def, frozen)
def join(included, excluded):
    """Inverse of ``split``: rebuild the full parameter pytree.

    Args:
        included: Sequence (list, tuple, ...) of the learnable leaves in
            flatten order.
        excluded: ``(tree_def, [(flat_index, leaf), ...])`` as produced by
            ``split``; the flat indices must be in ascending order.

    Returns:
        The pytree described by ``tree_def`` with every leaf restored to its
        original flat position.
    """
    tree_def, excluded_leaves = excluded
    # list() copies, so the caller's sequence is never mutated. Unlike
    # .copy(), it also accepts tuples and other sequences.
    all_leaves = list(included)
    # Inserting in ascending index order puts each excluded leaf back at its
    # original position, because every earlier position is already filled.
    for idx, leaf in excluded_leaves:
        all_leaves.insert(idx, leaf)
    return tu.tree_unflatten(tree_def, all_leaves)
# Adam as an (init, update, get_params) triple; keep both the triple and its
# unpacked parts at module level.
optimizer = jax.experimental.optimizers.adam(step_size=0.01)
opt_init, opt_update, get_params = optimizer

# "c" is a float and is trained; "d" is an int, so split() holds it fixed.
initial_params = {"c": 2.0, "d": 1}
opt_state, step = make_resilient_step(
    func, initial_params, split=split, join=join, optimizer=optimizer
)

for i in range(1000):
    l, opt_state = step(opt_state)

# Re-attach the untouched integer leaves to the trained float leaves.
pars = get_params(opt_state)
final_pars = join(pars, split(initial_params)[1])
print(final_pars)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment