Skip to content

Instantly share code, notes, and snippets.

@goldingn
Last active October 2, 2018 12:01
Show Gist options
  • Save goldingn/bea9e3dac12cbcb53e4e3fa8456d77bc to your computer and use it in GitHub Desktop.
Save goldingn/bea9e3dac12cbcb53e4e3fa8456d77bc to your computer and use it in GitHub Desktop.
A prototype and demonstration of solving ODEs with greta
# prototype ODE solver function for greta
# user-facing function to export:
#
# Solve a system of ordinary differential equations as a greta operation.
#
# Arguments:
#   derivative  a function with the first two arguments being 'y' and 't',
#               and subsequent arguments representing (temporally static)
#               model parameters
#   y0          a greta array representing the shape of y at time 0
#   times       a column vector of times at which to evaluate y (coerced to
#               a greta array)
#   ...         the additional (fixed) parameters passed on to derivative;
#               each is coerced to a greta array
#
# Returns a greta array with dimensions c(number of times, dim(y0)), where
# the first dimension of y0 is dropped if it is 1.
#
# Errors if derivative is not a function, or if times is not a column vector.
ode_solve <- function (derivative, y0, times, ...) {

  # check derivative is a function
  if (!is.function(derivative)) {
    stop("derivative must be a function",
         call. = FALSE)
  }

  times <- as.greta_array(times)

  # check times is a column vector: exactly two dimensions, the second of
  # which has length one
  t_dim <- dim(times)
  if (length(t_dim) != 2 || t_dim[2] != 1) {
    stop("times must be a column vector, but has dimensions ",
         paste(t_dim, collapse = " x "),
         call. = FALSE)
  }

  dots <- list(...)
  dots <- lapply(dots, as.greta_array)

  # TODO: check all additional parameters are named and there are no extras
  # TODO: check the arguments of derivative are valid and match dots

  # create a tensorflow version of the function
  tf_derivative <- as_tf_derivative(derivative, y0, times, dots)

  # the dimensions should be the dimensions of the y0, duplicated along times
  n_time <- t_dim[1]
  y0_dims <- dim(y0)

  # drop the first element if it's a one
  if (y0_dims[1] == 1) {
    y0_dims <- y0_dims[-1]
  }

  dims <- c(n_time, y0_dims)

  # the tensorflow derivative is handed to tf_ode_solve as an operation
  # argument, alongside the tensors for y0, times and the dots
  op("ode", y0, times, ...,
     dim = dims,
     tf_operation = "tf_ode_solve",
     operation_args = list(tf_derivative = tf_derivative))

}
# internal tensorflow-level counterpart of ode_solve, wrapping the core TF
# integrator
#
# Arguments:
#   y0             tensor for the starting state
#   times          tensor for the times at which to evaluate the solution
#   ...            tensors for the other parameters of the derivative
#   tf_derivative  function of tensors (y, t) returning dydt, as built by
#                  as_tf_derivative
#
# Returns a tensor for the integral of the derivative evaluated at times,
# with the first two dimensions of the integrator's output swapped.
tf_ode_solve <- function(y0, times, ..., tf_derivative) {

  # reduce times to what the integrator is given: flatten, keep only the
  # first slice along the leading dimension, then squeeze that dimension
  # out (presumably the leading dimension is greta's batch dimension -
  # confirm against tf_flatten)
  flat_times <- tf_flatten(times)
  first_slice <- tf$slice(flat_times, c(0L, 0L), c(1L, -1L))
  time_points <- tf$squeeze(first_slice, 0L)

  # make the parameter tensors visible to tf_derivative, which reads
  # 'tf_dots' from its enclosing environment
  assign("tf_dots", list(...), envir = environment(tf_derivative))

  # the integration has to be set up with the dag's TF graph as default, so
  # that tf_derivative creates its tensors correctly; 'dag' is taken from the
  # frame that called this function, and the assignment to 'solution' happens
  # in this frame when on_graph evaluates its argument
  dag <- parent.frame()$dag
  dag$on_graph(
    solution <- tf$contrib$integrate$odeint(tf_derivative, y0, time_points)
  )

  # swap the first two dimensions, to put the batch dimension first
  axes <- seq_along(dim(solution)) - 1L
  axes <- replace(axes, 1:2, axes[2:1])
  solution <- tf$transpose(solution, perm = axes)

  # if the first (non-batch) dimension of y0 was 1, drop it in the results
  if (dim(y0)[[2]] == 1) {
    solution <- tf$squeeze(solution, 2L)
  }

  solution

}
# convert a greta/R derivative function into a tensorflow function
#
# Arguments:
#   derivative  the greta/R function for the derivative
#   y, t        greta arrays for the state and the times
#   dots        list of greta arrays for the remaining parameters
#
# Returns a function of two tensors (y, t) that returns a tensor for dydt.
# The parameter tensors are not arguments of the returned function; they are
# read from 'tf_dots' in its enclosing environment.
as_tf_derivative <- function (derivative, y, t, dots) {

  # tensor-level version of derivative, acting on the full set of inputs
  greta_inputs <- c(list(r_fun = derivative, y = y, t = t), dots)
  tf_fun <- do.call(as_tf_function, greta_inputs)

  # placeholder (also keeps CRAN's checks quiet); tf_ode_solve overwrites
  # this binding with the list of parameter tensors before integrating
  tf_dots <- NULL

  # the ode solver only supplies y and t, so append the stored parameters
  function (y, t) {
    do.call(tf_fun, c(list(y = y, t = t), tf_dots))
  }

}
# ~~~~~~~~
# test run
# everything below is a demonstration script exercising ode_solve()
library(greta)
# load required internal methods
# ode_solve(), tf_ode_solve() and as_tf_derivative() look these names up when
# they are called, so the aliases must exist before the solver is used
as.greta_array <- .internals$greta_arrays$as.greta_array
op <- .internals$nodes$constructors$op
tf_flatten <- greta:::tf_flatten
as_tf_function <- .internals$utils$greta_array_operations$as_tf_function
# provides the 'tf' object used inside tf_ode_solve()
library(tensorflow)
# simulate data using the Lotka-Volterra example from deSolve
# NOTE: the seed is an arithmetic expression, not a date: 2018 - 03 - 14
# evaluates to 2001, so this is equivalent to set.seed(2001)
set.seed(2018-03-14)
# provides ode(), used below to simulate the data
library (deSolve)
# Lotka-Volterra derivative function in the form expected by deSolve::ode
#
# Arguments:
#   Time   the current time (not used in the calculation)
#   State  named vector with elements 'Prey' and 'Predator'
#   Pars   named vector with elements 'rIng', 'rGrow', 'rMort', 'assEff'
#          and 'K'
#
# Returns a list with one element: the vector of rates of change
# c(prey, predator).
LVmod <- function(Time, State, Pars) {

  # evaluate the model with the states and parameters available by name
  quantities <- as.list(c(State, Pars))

  with(quantities, {

    # prey consumed by predators
    eaten <- rIng * Prey * Predator

    # logistic prey growth, less the prey that were eaten
    prey_growth <- rGrow * Prey * (1 - Prey / K)
    prey_rate <- prey_growth - eaten

    # assimilated ingestion, less predator mortality
    predator_loss <- rMort * Predator
    predator_rate <- eaten * assEff - predator_loss

    list(c(prey_rate, predator_rate))

  })

}
# 'true' parameter values used to generate the data
pars <- c(
  rIng = 0.2,    # /day, rate of ingestion
  rGrow = 1.0,   # /day, growth rate of prey
  rMort = 0.2,   # /day, mortality rate of predator
  assEff = 0.5,  # -, assimilation efficiency
  K = 10         # mmol/m3, carrying capacity
)

# starting state and the times at which the solution is recorded
yini <- c(Prey = 1, Predator = 2)
times <- seq(from = 0, to = 50, by = 1)

# numerical solution with deSolve
out <- ode(y = yini, times = times, func = LVmod, parms = pars)

# simulate observations: drop the first column of the solution and add
# gaussian noise to the remaining two columns
jitter <- rnorm(n = 2 * length(times), mean = 0, sd = 0.1)
y_obs <- out[, -1] + matrix(jitter, ncol = 2)
# ~~~~~~~~~
# fit a greta model to infer the parameters from this simulated data

# greta version of the Lotka-Volterra derivative function
#
# Arguments:
#   y       the current state: prey in element [1, 1], predator in [1, 2]
#   t       the current time (not used in the calculation)
#   rIng, rGrow, rMort, assEff, K
#           the model parameters
#
# Returns the rates of change, column-bound as (dPrey, dPredator).
lotka_volterra <- function(y, t, rIng, rGrow, rMort, assEff, K) {

  prey <- y[1, 1]
  predator <- y[1, 2]

  # prey consumed by predators
  eaten <- rIng * prey * predator

  # logistic prey growth, less the prey that were eaten
  prey_growth <- rGrow * prey * (1 - prey / K)
  dPrey <- prey_growth - eaten

  # assimilated ingestion, less predator mortality
  predator_loss <- rMort * predator
  dPredator <- eaten * assEff - predator_loss

  cbind(dPrey, dPredator)

}
# uniform priors on each of the model parameters
rIng <- uniform(min = 0, max = 0.4)         # /day, rate of ingestion
rGrow <- uniform(min = 0.8, max = 1.2)      # /day, growth rate of prey
rMort <- uniform(min = 0, max = 0.4)        # /day, mortality rate of predator
assEff <- uniform(min = 0.25, max = 0.75)   # -, assimilation efficiency
K <- uniform(min = 8, max = 12)             # mmol/m3, carrying capacity

# initial state (a 1 x 2 array; the second element is offset by one) and the
# observation error
y0 <- uniform(min = 0.5, max = 1.5, dim = c(1, 2)) + t(0:1)
obs_sd <- uniform(min = 0, max = 0.5)

# solution to the ODE at the observation times; the parameters are passed
# positionally, in the order of the arguments of lotka_volterra
y <- ode_solve(derivative = lotka_volterra,
               y0 = y0,
               times = times,
               rIng, rGrow, rMort, assEff, K)

# sampling statement/observation model: gaussian error around the solution
distribution(y_obs) <- normal(mean = y, sd = obs_sd)
# greta can also be used purely as a solver: fix the initial state and all of
# the parameters (here at the values used to simulate the data) and compute
# the solution directly
values <- append(list(y0 = matrix(1:2, nrow = 1)),
                 as.list(pars))
vals <- calculate(y, values)

# solution: first column as a solid line, second column as a dashed line
plot(vals[, 1] ~ times,
     ylim = range(vals),
     type = "l")
lines(vals[, 2] ~ times,
      lty = 2)

# simulated observations: circles for the first column, triangles for the
# second
points(y_obs[, 1] ~ times)
points(y_obs[, 2] ~ times,
       pch = 2)
# or we can do inference on the parameters:

# build the model (takes a few seconds to define the tensorflow graph)
m <- model(rIng, rGrow, rMort, assEff, K, obs_sd)

# run MCMC on it - warning: this takes a lot of CPU-time, especially with HMC!!
draws <- mcmc(m, sampler = rwmh())
plot(draws)

# we can get predictive posteriors for the solution, even after sampling the
# other parameters
y_upper <- y + 1.96 * obs_sd
y_lower <- y - 1.96 * obs_sd

y_upper_draws <- calculate(y_upper, draws)
y_lower_draws <- calculate(y_lower, draws)
y_draws <- calculate(y, draws)

# posterior means of the solution and of the interval bounds
y_upper_mean <- summary(y_upper_draws)$statistics[, "Mean"]
y_lower_mean <- summary(y_lower_draws)$statistics[, "Mean"]
y_mean <- summary(y_draws)$statistics[, "Mean"]

# plot the posterior mean predictions and intervals

# the summaries are indexed as all prey elements followed by all predator
# elements (assumes the elements of y are summarised in column-major order -
# TODO confirm)
prey_idx <- seq_along(times)
predator_idx <- prey_idx + length(prey_idx)

# y axis limits covering both species' intervals and the observed points
# (the original referred to an undefined object here)
y_limits <- range(y_lower_mean, y_upper_mean, y_obs)

# prey in black
plot(y_mean[prey_idx] ~ times,
     type = "l",
     lwd = 2,
     ylim = y_limits)
lines(y_upper_mean[prey_idx] ~ times,
      lty = 2)
lines(y_lower_mean[prey_idx] ~ times,
      lty = 2)

# predators in blue
lines(y_mean[predator_idx] ~ times,
      lwd = 2,
      col = "blue")
lines(y_upper_mean[predator_idx] ~ times,
      lty = 2,
      col = "blue")
lines(y_lower_mean[predator_idx] ~ times,
      lty = 2,
      col = "blue")

# plot the observed points over the top
points(y_obs[, 1] ~ times)
points(y_obs[, 2] ~ times, col = "blue")
@goldingn
Copy link
Author

goldingn commented Oct 2, 2018

This is now implemented on the dev branch of greta.dynamics: https://github.com/greta-dev/greta.dynamics/tree/dev

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment