Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save seabbs/027cd1c439e8acf1d598cc03ef33aaa4 to your computer and use it in GitHub Desktop.
Save seabbs/027cd1c439e8acf1d598cc03ef33aaa4 to your computer and use it in GitHub Desktop.
This gist shows how to estimate a doubly censored (i.e. daily data) and right-truncated (i.e. due to the epidemic phase) distribution using the brms package.
# Load packages
library(brms)
library(cmdstanr)
library(data.table) # here we use the development version of data.table install it with data.table::update_dev_pkg
library(purrr)
# Set up parallel cores
options(mc.cores = 4)

# Simulate some censored and truncated data: settings ----
init_cases <- 100 # scale of the expected case curve (its value at time 0)
growth_rate <- 0.1 # exponential growth rate per time step
max_t <- 20 # last simulated time step; later also used as the cut-off time
samples <- 400 # number of line-list rows to sample
# note we actually won't end up with this many samples as some will be truncated
logmean <- 1.6 # meanlog of the simulated lognormal delay
logsd <- 0.6 # sdlog of the simulated lognormal delay

# Simulate the underlying outbreak structure assuming exponential growth.
# rpois() is vectorised over its rate, so one call draws a Poisson count for
# every time step without mapping over the rates one at a time.
cases <- data.table(
  cases = rpois(max_t, init_cases * exp(growth_rate * seq_len(max_t))),
  time = seq_len(max_t)
)
plot(cases$cases)

# Make a case line list with one row per simulated case.
# Inside `j`, `cases` is the per-time-step count column. seq_len() gives no
# rows for a time step with zero cases, whereas `1:0` would produce two
# spurious rows (ids 1 and 0). Standard `[` syntax is used so this step does
# not rely on the DT() helper from the development version of data.table.
linelist <- cases[, .(id = seq_len(cases)), by = time]
# Simulate the observation process for the line list.
# Each sampled case gets a continuous lognormal delay from its (integer)
# primary event time to its secondary event.
obs <- data.table(
  time = sample(linelist$time, samples, replace = FALSE),
  delay = rlnorm(samples, logmean, logsd)
)
# The columns below are added by reference with standard `[` + `:=` syntax so
# that this step does not rely on the DT() helper from the development
# version of data.table.
# Add a new ID (seq_len() is safe even if there are no rows)
obs[, id := seq_len(.N)]
# When would data be observed
obs[, obs_delay := time + delay]
# Integerise delay (day on which the secondary event falls)
obs[, daily_delay := floor(delay)]
# Day after observations (upper bound of the daily interval)
obs[, day_after_delay := ceiling(delay)]
# Time observed for; `max_t` is still the global setting here because the
# `max_t` column is only added further down
obs[, obs_time := max_t - time]
# We don't know this exactly so need to censor
# Set to the midday point as average across day
obs[, censored_obs_time := obs_time - 0.5]
# Label every row as interval censored (used by cens() in the brms formulas)
obs[, censored := "interval"]

# Make event based data for latent modelling
obs[, primary_event := floor(time)]
obs[, secondary_event := floor(obs_delay)]
# Store the cut-off time as a column so it can be passed to the model
obs[, max_t := max_t]

# Truncate observations: keep only secondary events that occur by max_t
truncated_obs <- obs[obs_delay <= max_t]

# The lognormal family in brms does not support 0 so also drop delays
# shorter than one day (i.e. those with a daily_delay of 0)
# This seems like it could be improved
double_truncated_obs <- truncated_obs[daily_delay >= 1]
# Baseline: intercept-only lognormal fit to the daily delays with no
# correction for censoring or truncation
naive_model <- brm(
  formula = bf(daily_delay ~ 1, sigma ~ 1),
  family = lognormal(),
  data = double_truncated_obs,
  backend = "cmdstanr",
  adapt_delta = 0.9
)

# The estimated log mean is affected by the truncation that is not modelled.
# Note that sigma_Intercept must be exponentiated to obtain the log sd.
summary(naive_model)
# Truncation-adjusted fit: the response is bounded below at 1 and above at
# each row's censored observation time
trunc_model <- brm(
  formula = bf(
    daily_delay | trunc(lb = 1, ub = censored_obs_time) ~ 1,
    sigma ~ 1
  ),
  family = lognormal(),
  data = double_truncated_obs,
  backend = "cmdstanr",
  adapt_delta = 0.9
)

# Estimates move closer to the simulated values
summary(trunc_model)
# Censoring-adjusted fit: each delay is treated as interval censored between
# daily_delay and day_after_delay
censor_model <- brm(
  formula = bf(
    daily_delay | cens(censored, day_after_delay) ~ 1,
    sigma ~ 1
  ),
  family = lognormal(),
  data = double_truncated_obs,
  backend = "cmdstanr",
  adapt_delta = 0.9
)

# Not as close as the truncation-adjusted fit, but better than the naive one
summary(censor_model)
# Combined fit: adjust for both interval censoring and truncation
censor_trunc_model <- brm(
  formula = bf(
    daily_delay |
      trunc(lb = 1, ub = censored_obs_time) +
      cens(censored, day_after_delay) ~ 1,
    sigma ~ 1
  ),
  family = lognormal(),
  data = double_truncated_obs,
  backend = "cmdstanr"
)

# This should recover the underlying distribution.
# With higher growth rates and short delays some bias may remain because the
# observation time itself is censored.
summary(censor_trunc_model)
# Model censoring as a latent process (WIP)
# This approach relies on a custom brms family, which makes the code
# considerably more involved than the fits above.

#' Fit a lognormal model with latent censoring windows and truncation
#'
#' Builds a custom brms family (`latent_lognormal`) together with its Stan
#' density and uniform(0, 1) priors on the window coefficients, then hands
#' them to the fitting function.
#'
#' In the Stan density the delay is `(sevent + swindow) - (y + pwindow)` and
#' the lognormal log density of that delay is reduced by the lognormal log
#' CDF evaluated at `end_t - (y + pwindow)`.
#'
#' @param fn Fitting function to call; defaults to `brm`. It is called with
#'   `family`, `stanvars` and `prior` set here plus everything in `...`.
#' @param ... Further arguments forwarded unchanged to `fn` (for example the
#'   model formula, `data` and `backend`).
#' @return Whatever `fn` returns.
fit_latent_lognormal <- function(fn = brm, ...) {
  # Stan density for the custom family; `sevent` and `end_t` are supplied
  # through the two vreal terms declared in the family below.
  stan_code <- "
real latent_lognormal_lpdf(real y, real mu, real sigma, real pwindow,
real swindow, real sevent,
real end_t) {
real p = y + pwindow;
real s = sevent + swindow;
real d = s - p;
real obs_time = end_t - p;
return lognormal_lpdf(d | mu, sigma) - lognormal_lcdf(obs_time | mu, sigma);
}
"
  latent_stanvars <- stanvar(scode = stan_code, block = "functions")

  # Custom family: mu and sigma as for a lognormal, plus the primary and
  # secondary event windows, each constrained to [0, 1].
  latent_family <- custom_family(
    "latent_lognormal",
    dpars = c("mu", "sigma", "pwindow", "swindow"),
    links = c("identity", "log", "identity", "identity"),
    lb = c(NA, 0, 0, 0),
    ub = c(NA, NA, 1, 1),
    type = "real",
    vars = c("vreal1[n]", "vreal2[n]")
  )

  # Shared priors: uniform over the unit interval for both windows
  window_priors <- c(
    prior(uniform(0, 1), class = "b", dpar = "pwindow", lb = 0, ub = 1),
    prior(uniform(0, 1), class = "b", dpar = "swindow", lb = 0, ub = 1)
  )

  fit <- fn(
    family = latent_family,
    stanvars = latent_stanvars,
    prior = window_priors,
    ...
  )
  fit
}
# Fit the latent lognormal model to the truncated (but not zero-filtered)
# observations. The two vreal() terms pass the secondary event day and the
# cut-off time to the custom density; each observation gets its own primary
# and secondary window coefficient via the per-id factor terms.
latent_model <- fit_latent_lognormal(
  fn = brm,
  bf(
    primary_event | vreal(secondary_event, max_t) ~ 1,
    sigma ~ 1,
    pwindow ~ 0 + as.factor(id),
    swindow ~ 0 + as.factor(id)
  ),
  data = truncated_obs,
  backend = "cmdstanr",
  adapt_delta = 0.95
)

# Parameter recovery is also expected with this approach, although run-times
# are much higher and the model is somewhat unstable.
summary(latent_model)
@parksw3
Copy link

parksw3 commented Oct 27, 2022

I merged both branches so you can also look at the main branch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment