Autograd implementation of a packed ODE solver with augmented state
""" | |
From Appendix D in the Neural ODE paper | |
Implementation of autograd | |
""" | |
import scipy.integrate

import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple

# Wrap scipy's odeint as an autograd primitive; its custom VJP is
# registered at the bottom of this file.
odeint = primitive(scipy.integrate.odeint)

def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    """
    Extended from "Scalable Inference of Ordinary Differential
    Equation Models of Biochemical Processes", Sec. 2.4.2.
    Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017.
    https://arxiv.org/pdf/1711.08079.pdf
    """
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        # Same dynamics, but taking the flattened parameter vector.
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        #      y,      vjp_y,       vjp_t,    vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]

    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
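
    # For D state dimensions and K flattened parameters, the augmented
    # state is a flat vector of length 2 * D + 1 + K, laid out as
    #   [ y (D) | vjp_y (D) | vjp_t (1) | vjp_args (K) ],
    # which is what unpack() slices apart.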

    def vjp_all(g):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute the effect of moving the current measurement time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run the augmented system backwards to the previous observation;
            # **kwargs are the solver options captured from grad_odeint_all.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)), **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add the gradient contribution from the current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)
    return vjp_all

def grad_argnums_wrapper(all_vjp_builder):
    """
    A generic autograd helper function: takes a function that builds
    vjps for all arguments and wraps it to return only the requested vjps.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Return whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]
        return chosen_vjps
    return build_selected_vjps

# Register the custom VJP so autograd can differentiate through odeint.
defvjp_argnums(odeint, grad_argnums_wrapper(grad_odeint_all))
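
A minimal usage sketch (not part of the gist itself): the dynamics
damped_linear, the time grid, and the loss below are illustrative
placeholders; any smooth func(y, t, *args) should work. With the VJP
registered above, autograd's grad differentiates a scalar loss of the
solution with respect to the initial state:

if __name__ == '__main__':
    from autograd import grad

    def damped_linear(y, t, a, b):
        # Hypothetical 2-D linear system for illustration:
        # dy/dt = [[a, b], [-b, a]] @ y
        return np.array([a * y[0] + b * y[1],
                         -b * y[0] + a * y[1]])

    t = np.linspace(0., 1., 10)
    y0 = np.array([1., 0.])
    func_args = (-0.5, 2.0)  # (a, b), chosen arbitrarily

    def loss(y0):
        # Scalar function of the trajectory: squared norm of the final state.
        yt = odeint(damped_linear, y0, t, func_args)
        return np.sum(yt[-1] ** 2)

    # d(loss)/d(y0), computed through the custom VJP.
    print(grad(loss)(y0))

Since vjp_all returns vjps for y0, t and func_args (argnums 1, 2 and 3),
gradients with respect to the measurement times and the dynamics
parameters are available the same way.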