"""Sampling parameters of a lorenz attractor. | |
The forward pass integrates the lorenz attractor ODE system using | |
tt.scan with a Runge-Kutta integrator. The predicted high-resolution | |
timecourse is interpolated down so it can be compared to low-density | |
observations. | |
""" | |
import abc | |
import numpy | |
import pymc3 | |
from matplotlib import pyplot | |
import theano | |
import theano.tensor as tt | |
import scipy.integrate | |
def interpolation_weights(x_predicted, x_interpolated):
    """Computes weights for use in a left-handed dot product with Y_predicted.

    Args:
        x_predicted (numpy.ndarray): x-values at which Y is predicted
        x_interpolated (numpy.ndarray): x-values for which Y is desired

    Returns:
        weights (numpy.ndarray): weights for Y_desired = dot(weights, Y_predicted)
    """
    # pairwise distances between desired and predicted x-coordinates
    x_repeat = numpy.tile(x_interpolated[:, None], (len(x_predicted),))
    distances = numpy.abs(x_repeat - x_predicted)
    x_indices = numpy.searchsorted(x_predicted, x_interpolated)
    # each desired point weights its two neighbors inversely proportional to
    # the distance: the left neighbor gets the distance to the right one and
    # vice versa, then the row is normalized to sum to 1
    weights = numpy.zeros_like(distances)
    idx = numpy.arange(len(x_indices))
    weights[idx, x_indices] = distances[idx, x_indices - 1]
    weights[idx, x_indices - 1] = distances[idx, x_indices]
    weights /= numpy.sum(weights, axis=1)[:, None]
    return weights
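
# A small worked example of the weights (not executed here): for
# x_predicted=[0., 1., 2.] and x_interpolated=[0.25], searchsorted gives
# index 1, the normalized row becomes [0.75, 0.25, 0.], and
# dot(weights, Y) = 0.75*Y[0] + 0.25*Y[1], i.e. linear interpolation
# between the two neighboring predictions.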
class Integrator(object):
    """Abstract class of an ODE solver to be used with Theano scan."""
    __metaclass__ = abc.ABCMeta

    def step(self, t, dt, y, dydt, theta):
        """Symbolic integration step.

        Args:
            t (TensorVariable): timepoint to be passed to [dydt]
            dt (TensorVariable): stepsize
            y (TensorVariable): vector of current states
            dydt (callable): dydt function of the model like dydt(y, t, theta)
            theta (TensorVariable): system parameters

        Returns:
            TensorVariable: state of the system after the step (y at t + dt)
        """
        raise NotImplementedError()
class RungeKutta(Integrator):
    """Classical fourth-order Runge-Kutta (RK4) scheme."""
    def step(self, t, dt, y, dydt, theta):
        # intermediate slopes of the classical RK4 scheme; the midpoint
        # evaluations use t + dt/2 (irrelevant for the autonomous Lorenz
        # system, but required for time-dependent systems)
        k1 = dt*dydt(y, t, theta)
        k2 = dt*dydt(y + 0.5*k1, t + 0.5*dt, theta)
        k3 = dt*dydt(y + 0.5*k2, t + 0.5*dt, theta)
        k4 = dt*dydt(y + k3, t + dt, theta)
        y_np1 = y + (1./6.)*k1 + (1./3.)*k2 + (1./3.)*k3 + (1./6.)*k4
        return y_np1
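
# A minimal sketch of an alternative Integrator subclass (explicit Euler),
# included only to illustrate the interface; it is not used below, and a
# much smaller dt would be needed for comparable accuracy:
class Euler(Integrator):
    def step(self, t, dt, y, dydt, theta):
        # first-order update: y_{n+1} = y_n + dt * f(y_n, t)
        return y + dt*dydt(y, t, theta)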
class TheanoIntegrationOps(object):
    """This is not actually a real Op, but it can be used as if it were.

    It does differentiable solving of a dynamic system using the provided
    [integrator]. When called, it creates all the steps in the computation
    graph that lead from y0/theta to Y_hat.
    """
    def __init__(self, dydt_theano, integrator: Integrator):
        """Creates a pseudo-Op that uses the [integrator] to solve [dydt_theano].

        Args:
            dydt_theano (callable): function that computes the first derivative of the system
            integrator (Integrator): integrator to use for solving
        """
        self.dydt_theano = dydt_theano
        self.integrator = integrator
        super().__init__()
    def __step_theano(self, t, y_t, dt_t, t_theta):
        """Step method that will be used in theano.scan.

        Uses the integrator to give a better approximation than dydt alone.

        Args:
            t (TensorVariable): time since the initial state
            y_t (TensorVariable): current state of the system
            dt_t (TensorVariable): stepsize
            t_theta (TensorVariable): system parameters

        Returns:
            TensorVariable: state of the system at time t + dt
        """
        return self.integrator.step(t, dt_t, y_t, self.dydt_theano, t_theta)
    def __call__(self, y0, theta, dt, n):
        """Creates the computation graph for solving the ODE system.

        Args:
            y0 (TensorVariable or array): initial system state
            theta (TensorVariable or array): system parameters
            dt (float): fixed stepsize for solving
            n (int): number of solving iterations

        Returns:
            TensorVariable: system state y for all t in numpy.arange(n)*dt with shape (len(y0), n)
        """
        # TODO: check dtypes, stack and raise warnings
        t_y0 = tt.as_tensor_variable(y0)
        t_theta = tt.as_tensor_variable(theta)
        Y_hat, updates = theano.scan(fn=self.__step_theano,
                                     outputs_info=[{'initial': t_y0}],
                                     sequences=[tt.arange(dt, dt*n, dt)],
                                     non_sequences=[dt, t_theta],
                                     n_steps=n-1)
        # scan does not return y0, so it must be concatenated
        Y_hat = tt.concatenate((t_y0[None, :], Y_hat))
        # return as (len(y0), n)
        Y_hat = tt.transpose(Y_hat)
        return Y_hat
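
# A usage sketch (hypothetical, not executed here): the pseudo-Op can also be
# compiled into a plain numeric integrator with theano.function:
#   t_y0 = tt.vector('y0')
#   t_theta = tt.vector('theta')
#   Y = TheanoIntegrationOps(dydt_theano, RungeKutta())(t_y0, t_theta, 0.001, 1001)
#   f = theano.function([t_y0, t_theta], Y)
#   f([2.2, 5.3, 7.4], [10.0, 28.0, 8/3])   # -> ndarray of shape (3, 1001)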
class InterpolationOps(object):
    """Linearly interpolates the entries in a tensor according to vectors of
    predicted and desired coordinates.
    """
    def __init__(self, x_predicted, x_interpolated):
        """Prepare an interpolation subgraph.

        Args:
            x_predicted (ndarray): x-coordinates for which Y will be predicted (T_pred,)
            x_interpolated (ndarray): x-coordinates for which Y is desired (T_data,)
        """
        assert x_interpolated[-1] <= x_predicted[-1], "x_predicted[-1]={} but " \
            "x_interpolated[-1]={}".format(x_predicted[-1], x_interpolated[-1])
        self.x_predicted = x_predicted
        self.x_interpolated = x_interpolated
        # the weight matrix is constant, so it is precomputed with numpy and
        # only the dot product ends up in the graph
        self.weights = tt.as_tensor_variable(
            interpolation_weights(x_predicted, x_interpolated))
        super().__init__()
    def __call__(self, Y_predicted):
        """Symbolically apply the interpolation.

        Args:
            Y_predicted (ndarray or TensorVariable): predictions at x_predicted with shape (N_Y, T_pred)

        Returns:
            Y_interpolated (TensorVariable): interpolated predictions at x_interpolated with shape (N_Y, T_data)
        """
        Y_predicted = tt.as_tensor_variable(Y_predicted)
        Y_interpolated = tt.dot(self.weights, tt.transpose(Y_predicted))
        return tt.transpose(Y_interpolated)
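
# A usage sketch (hypothetical): downsample a (3, 1001) prediction Y_hat onto
# the 20 observation timepoints used below:
#   interpolate = InterpolationOps(numpy.arange(1001)*0.001, numpy.linspace(0, 1, 20))
#   Y_small = interpolate(Y_hat)   # TensorVariable of shape (3, 20)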
def dydt(y, t, theta):
    """First derivative of the Lorenz system, for scipy.integrate.odeint."""
    sigma, rho, beta = theta
    yprime = [
        sigma*(y[1] - y[0]),
        y[0]*(rho - y[2]) - y[1],
        y[0]*y[1] - beta*y[2]
    ]
    return yprime
def dydt_theano(y, t, theta):
    """Symbolic twin of [dydt], operating on tensor variables."""
    # get parameters
    sigma = theta[0]
    rho = theta[1]
    beta = theta[2]
    # set up differential equations
    yprime = tt.zeros_like(y)
    yprime = tt.set_subtensor(yprime[0], sigma*(y[1] - y[0]))
    yprime = tt.set_subtensor(yprime[1], y[0]*(rho - y[2]) - y[1])
    yprime = tt.set_subtensor(yprime[2], y[0]*y[1] - beta*y[2])
    return yprime
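
# A consistency-check sketch (hypothetical, not executed here): both
# derivative implementations should agree on a test point:
#   t_y, t_th = tt.vector('y'), tt.vector('theta')
#   f = theano.function([t_y, t_th], dydt_theano(t_y, 0.0, t_th))
#   numpy.allclose(f([1., 2., 3.], [10., 28., 8/3]),
#                  dydt(numpy.array([1., 2., 3.]), 0.0, numpy.array([10., 28., 8/3])))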
def run():
    # true values of x, y, z, a, b, c
    truth = numpy.array([2.2, 5.3, 7.4, 9.5, 27, 9/3])
    x = numpy.linspace(0, 1, 20)
    y = scipy.integrate.odeint(dydt, truth[:3], x, (truth[3:],))
    #pyplot.plot(x, y)
    #pyplot.show()

    dt = 0.001
    integrator = RungeKutta()

    pmodel = pymc3.Model()
    with pmodel:
        # Priors
        x0 = pymc3.Uniform('x0', -5.01, 10.01, testval=2.0)
        y0 = pymc3.Uniform('y0', -2.01, 13.01, testval=5.0)
        z0 = pymc3.Uniform('z0', 0.01, 15.01, testval=7.0)
        a = pymc3.Normal('a', mu=10.0, sd=1.1)
        b = pymc3.Normal('b', mu=28.0, sd=1.2)
        c = pymc3.Normal('c', mu=8/3, sd=0.3)
        T_y0 = [x0, y0, z0]
        T_theta = [a, b, c]
        # Prediction
        x_max = x[-1]
        x_any = x
        n_any = len(x_any)
        n_pred = int(numpy.ceil(x_max / dt) + 1)
        Y_hat_TV = TheanoIntegrationOps(dydt_theano, integrator)(T_y0, T_theta, dt, n_pred)
        # the Theano implementation only predicts on this regular grid:
        x_pred = numpy.arange(n_pred) * dt
        # so the predictions usually must be interpolated to the observation timepoints
        if not numpy.array_equal(x_pred, x_any):
            Y_hat_TV = InterpolationOps(x_pred, x_any)(Y_hat_TV)
        # loglikelihood
        for i, label in enumerate(['x', 'y', 'z']):
            L = pymc3.Normal(label + '_obs', mu=Y_hat_TV[i], sd=0.1, observed=y.T[i])
    with pmodel:
        # IMPORTANT: must init=advi for performance reasons (see https://github.com/pymc-devs/pymc3/issues/2753)
        nutstrace = pymc3.sample(chains=1, njobs=1, init='advi')
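        # Optional follow-up (a sketch, assuming a 2018-era pymc3 where
        # traceplot is the plotting entry point):
        #   pymc3.traceplot(nutstrace)
        #   pyplot.show()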
if __name__ == '__main__':
    print('pymc3 version {}'.format(pymc3.__version__))
    run()