Skip to content

Instantly share code, notes, and snippets.

@patrickmineault
Created April 23, 2021 22:07
Show Gist options
  • Save patrickmineault/347802fb2ac501a967b6cbbafdbf3e28 to your computer and use it in GitHub Desktop.
Save patrickmineault/347802fb2ac501a967b6cbbafdbf3e28 to your computer and use it in GitHub Desktop.
def jax_curve_fit(f, xdata, ydata, p0, *, max_iter=10000, max_backtracks=10,
                  growth=2, backtrack=.1, alpha=.01, min_delta=1e-4):
    """Fit the parameters of ``f`` to data by least squares, using JAX.

    The interface mimics ``scipy.optimize.curve_fit``:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html

    The sum of squared residuals ``sum((f(xdata, *params) - ydata)**2)`` is
    minimized with gradient descent and a backtracking line search.
    Each iteration first grows the step size. It then shrinks the step until
    the loss decreases.

    Args:
        f: Model called as ``f(xdata, *params)``. It must be traceable by
            ``jax.jit`` and differentiable by ``jax.grad``.
        xdata: Independent variable(s), passed through to ``f`` unchanged.
        ydata: Observed values that ``f(xdata, *params)`` is compared with.
        p0: Initial parameter guess (sequence of numbers). It is converted to
            a float array, so integer guesses are accepted.
        max_iter: Maximum number of gradient-descent iterations.
        max_backtracks: Maximum number of step shrinkings per iteration.
        growth: Factor applied to the step size at the start of each
            iteration.
        backtrack: Factor applied to the step size after each rejected step.
        alpha: Initial step size. It is multiplied by ``growth`` before the
            first step is tried.
        min_delta: Stop once an accepted step lowers the loss by no more
            than this amount.

    Returns:
        ``(popt, None)``. ``popt`` is the array of fitted parameters. The
        second element stands in for scipy's covariance matrix, which is not
        computed here.
    """
    import numpy as np
    import jax.numpy as jnp
    from jax import grad, jit

    def sse(params, inputs, targets):
        # Sum of squared residuals: the quantity being minimized.
        preds = f(inputs, *params)
        return jnp.sum((preds - targets) ** 2)

    loss_fun = jit(sse)
    grad_fun = jit(grad(sse))

    # jax.grad rejects integer inputs, so force a floating-point dtype.
    params = np.asarray(p0, dtype=float)
    E0 = loss_fun(params, xdata, ydata)
    for _ in range(max_iter):
        g = grad_fun(params, xdata, ydata)
        # Try a larger step each iteration; backtracking below undoes
        # overshoots.
        alpha = alpha * growth
        for _ in range(max_backtracks):
            candidate = params - alpha * g
            E = loss_fun(candidate, xdata, ydata)
            if E < E0:
                break
            alpha = alpha * backtrack
        else:
            # No tried step lowered the loss: keep the current (best)
            # parameters instead of accepting a worse candidate, and stop.
            break
        params = candidate
        if E > E0 - min_delta:
            # The improvement was too small to keep iterating.
            break
        E0 = E
    return params, None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment