Last active
October 2, 2018 12:01
-
-
Save goldingn/bea9e3dac12cbcb53e4e3fa8456d77bc to your computer and use it in GitHub Desktop.
A prototype and demonstration of solving ODEs with greta
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# prototype ODE solver function for greta
#
# user-facing function to export: solve the ODE system defined by `derivative`
# from the initial state `y0`, returning a greta array of the state at `times`.
#
# derivative: an R function whose first two arguments are `y` (the current
#   state) and `t` (the current time); any further arguments are the
#   (temporally static) model parameters, supplied through `...`.
# y0: a greta array giving the state at time 0; its shape is the shape of y.
# times: a column vector (n x 1) of times at which to evaluate y; anything
#   accepted by as.greta_array() can be given.
# ...: greta arrays (or values coercible to them) for the additional fixed
#   parameters; they are appended after y and t when calling `derivative`
#   (presumably matched by name when named, otherwise by position).
#
# returns: a greta array with one row per time, followed by the dimensions of
#   y0 (a leading dimension of size one in y0 is dropped).
ode_solve <- function(derivative, y0, times, ...) {
  if (!is.function(derivative)) {
    stop("derivative must be a function", call. = FALSE)
  }

  times <- as.greta_array(times)

  # times must be a column vector: exactly two dimensions, a single column
  t_dim <- dim(times)
  if (length(t_dim) != 2 || t_dim[2] != 1) {
    stop("times must be a column vector", call. = FALSE)
  }

  dots <- lapply(list(...), as.greta_array)

  # TODO: check the arguments of derivative are valid and match dots

  # create a tensorflow version of the function
  tf_derivative <- as_tf_derivative(derivative, y0, times, dots)

  # the output has the dimensions of y0, stacked along times
  n_time <- t_dim[1]
  y0_dims <- dim(y0)
  # drop the first element if it's a one, so a row-vector state gives a matrix
  if (y0_dims[1] == 1) {
    y0_dims <- y0_dims[-1]
  }
  dims <- c(n_time, y0_dims)

  op("ode", y0, times, ...,
     dim = dims,
     tf_operation = "tf_ode_solve",
     operation_args = list(tf_derivative = tf_derivative))
}
# internal tensorflow function wrapping the core TF ODE integrator
#
# y0: tensor for the initial state, with a leading batch dimension
# times: tensor of evaluation times (batch dimension plus an n x 1 column)
# ...: tensors for the extra derivative parameters
# tf_derivative: the closure built by as_tf_derivative()
#
# returns a tensor for the integral of tf_derivative evaluated at times, given
# the starting state y0 and the parameters in `...`, with batch first
tf_ode_solve <- function(y0, times, ..., tf_derivative) {
  # reduce times to a 1D vector: flatten away the column, keep only the first
  # batch element, then drop the now-unit batch axis
  flat_times <- tf_flatten(times)
  first_times <- tf$slice(flat_times, c(0L, 0L), c(1L, -1L))
  time_vec <- tf$squeeze(first_times, 0L)

  # hand the parameter tensors to the derivative closure, which reads them
  # from its enclosing environment under the name `tf_dots`
  derivative_env <- environment(tf_derivative)
  assign("tf_dots", list(...), envir = derivative_env)

  # assumes the calling greta internals hold the dag in a variable `dag`;
  # integrate on its graph so tensors created by tf_derivative land there too
  dag <- parent.frame()$dag
  dag$on_graph(
    solution <- tf$contrib$integrate$odeint(tf_derivative, y0, time_vec)
  )

  # the integrator puts time first; swap the first two axes so batch leads
  n_axes <- length(dim(solution))
  perm <- seq_len(n_axes) - 1L
  perm[c(1L, 2L)] <- perm[c(2L, 1L)]
  solution <- tf$transpose(solution, perm = perm)

  # mirror ode_solve(), which drops a leading unit dimension of y0
  if (dim(y0)[[2]] == 1) {
    solution <- tf$squeeze(solution, 2L)
  }

  solution
}
# Build the derivative function handed to the TensorFlow ODE solver.
#
# derivative: the user's R/greta derivative function (y, t, then parameters)
# y, t: greta arrays standing in for the state and time
# dots: list of greta arrays for the extra parameters
#
# returns a function of the tensors y and t only, returning a tensor for dydt.
# The parameter tensors are not available yet: tf_ode_solve() later assigns
# them into this function's enclosing environment under the name `tf_dots`.
as_tf_derivative <- function(derivative, y, t, dots) {
  # placeholder binding (for CRAN's benefit); replaced by tf_ode_solve()
  # before the solver calls the returned closure
  tf_dots <- NULL

  # a tensorflow function acting on the full set of inputs
  full_args <- c(list(r_fun = derivative, y = y, t = t), dots)
  tf_fun <- do.call(as_tf_function, full_args)

  function(y, t) {
    do.call(tf_fun, c(list(y = y, t = t), tf_dots))
  }
}
# ~~~~~~~~
# test run

# greta for modelling, tensorflow for the `tf` object used above
library(greta)
library(tensorflow)

# internal greta helpers required by the prototype functions
as.greta_array <- .internals$greta_arrays$as.greta_array
op <- .internals$nodes$constructors$op
as_tf_function <- .internals$utils$greta_array_operations$as_tf_function
tf_flatten <- greta:::tf_flatten

# simulate data using the Lotka-Volterra example from deSolve
# (note: 2018-03-14 is evaluated as arithmetic, so the seed is 2001)
set.seed(2018-03-14)
library(deSolve)
# deSolve-style Lotka-Volterra derivative: returns a one-element list holding
# c(dPrey, dPredator), as ode() expects
LVmod <- function(Time, State, Pars) {
  with(as.list(c(State, Pars)), {
    ingestion <- rIng * Prey * Predator
    prey_growth <- rGrow * Prey * (1 - Prey / K)
    predator_mortality <- rMort * Predator
    list(c(
      prey_growth - ingestion,
      ingestion * assEff - predator_mortality
    ))
  })
}
# true parameter values used to simulate the data
pars <- c(
  rIng = 0.2,   # /day, rate of ingestion
  rGrow = 1.0,  # /day, growth rate of prey
  rMort = 0.2,  # /day, mortality rate of predator
  assEff = 0.5, # -, assimilation efficiency
  K = 10        # mmol/m3, carrying capacity
)
yini <- c(Prey = 1, Predator = 2)
times <- seq(0, 50, by = 1)
out <- ode(yini, times, LVmod, pars)

# simulate observations: independent normal noise (sd 0.1) on both states
jitter <- rnorm(2 * length(times), 0, 0.1)
y_obs <- out[, -1] + matrix(jitter, ncol = 2)
# ~~~~~~~~~
# fit a greta model to infer the parameters from this simulated data

# greta version of the derivative: y is a 1 x 2 array holding (prey, predator)
# and the result is the matching 1 x 2 array of rates of change
lotka_volterra <- function(y, t, rIng, rGrow, rMort, assEff, K) {
  prey <- y[1, 1]
  predator <- y[1, 2]
  ingestion <- rIng * prey * predator
  # keep these two names: cbind() uses them as the column names
  dPrey <- rGrow * prey * (1 - prey / K) - ingestion
  dPredator <- ingestion * assEff - rMort * predator
  cbind(dPrey, dPredator)
}
# uniform priors for the parameters, each bracketing the value in `pars`
rIng <- uniform(0, 0.4)        # /day, rate of ingestion
rGrow <- uniform(0.8, 1.2)     # /day, growth rate of prey
rMort <- uniform(0, 0.4)       # /day, mortality rate of predator
assEff <- uniform(0.25, 0.75)  # -, assimilation efficiency
K <- uniform(8, 12)            # mmol/m3, carrying capacity

# initial state (prey in [0.5, 1.5], predator shifted up by one) and the
# observation error
y0 <- uniform(0.5, 1.5, dim = c(1, 2)) + t(0:1)
obs_sd <- uniform(0, 0.5)

# solution of the ODE at every observation time
y <- ode_solve(lotka_volterra, y0, times, rIng, rGrow, rMort, assEff, K)

# observation model linking the solution to the noisy data
distribution(y_obs) <- normal(y, obs_sd)
# we can use greta to solve directly, for a fixed set of parameters (the true
# ones in this case, plus the true initial state)
values <- c(list(y0 = t(1:2)),
            as.list(pars))
# pass `values` by name: newer greta versions take targets through `...`, so a
# positional second argument would be treated as another target
vals <- calculate(y, values = values)

# solution lines with the observations on top; the y range includes the
# observations so none are clipped
plot(vals[, 1] ~ times, type = "l", ylim = range(vals, y_obs))
lines(vals[, 2] ~ times, lty = 2)
points(y_obs[, 1] ~ times)
points(y_obs[, 2] ~ times, pch = 2)
# alternatively, infer the parameters from the data:
# building the model defines the tensorflow graph, which takes a few seconds
m <- model(rIng, rGrow, rMort, assEff, K, obs_sd)

# sample with MCMC -- warning: this is CPU-intensive, especially with HMC!
draws <- mcmc(m, sampler = rwmh())
plot(draws)
# we can get predictive posteriors for the solution, even after sampling the
# other parameters; y +/- 1.96 * obs_sd approximates a 95% interval for new
# observations under the normal observation model
y_upper <- y + 1.96 * obs_sd
y_lower <- y - 1.96 * obs_sd

# pass the draws by name: newer greta versions take targets through `...`
y_upper_draws <- calculate(y_upper, values = draws)
y_lower_draws <- calculate(y_lower, values = draws)
y_draws <- calculate(y, values = draws)

# posterior means; the indexing below assumes column-major flattening
# (prey at every time, then predator at every time)
y_upper_mean <- summary(y_upper_draws)$statistics[, "Mean"]
y_lower_mean <- summary(y_lower_draws)$statistics[, "Mean"]
y_mean <- summary(y_draws)$statistics[, "Mean"]

prey_idx <- seq_along(times)
predator_idx <- prey_idx + length(prey_idx)

# plot the posterior mean predictions and intervals; the y range covers the
# intervals and the observations so nothing is clipped
plot(y_mean[prey_idx] ~ times,
     type = "l",
     lwd = 2,
     ylim = range(y_lower_mean, y_upper_mean, y_obs))
lines(y_upper_mean[prey_idx] ~ times,
      lty = 2)
lines(y_lower_mean[prey_idx] ~ times,
      lty = 2)
lines(y_mean[predator_idx] ~ times,
      lwd = 2,
      col = "blue")
lines(y_upper_mean[predator_idx] ~ times,
      lty = 2,
      col = "blue")
lines(y_lower_mean[predator_idx] ~ times,
      lty = 2,
      col = "blue")

# plot the observed points over the top
points(y_obs[, 1] ~ times)
points(y_obs[, 2] ~ times, col = "blue")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is now implemented on the dev branch of greta.dynamics: https://github.com/greta-dev/greta.dynamics/tree/dev