Forked from seabbs/interveral-censored-right-truncated-distribution-estimation-with-brms.R
Last active
October 26, 2022 15:41
-
-
Save sbfnk/569ad82641d73286a355317e53d911cf to your computer and use it in GitHub Desktop.
This gist shows how to estimate a doubly interval-censored (i.e., both event times are only known to the day, as with daily data) and right-truncated (i.e., due to the epidemic phase) delay distribution using the brms package.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load packages | |
library(brms) | |
library(cmdstanr) | |
library(data.table) # here we use the development version of data.table install it with data.table::update_dev_pkg | |
library(purrr) | |
library(bpmodels) # devtools::install_github("sbfnk/bpmodels") | |
# Set up parallel cores
options(mc.cores = 4)

# Simulation settings ----------------------------------------------------------
# Simulate some interval censored and right truncated delay data
## 500 cases at time 0 (which is not the same as 500 daily cases in an ongoing
## epidemic)
init_cases <- 500
## exponential growth rate of the simulated epidemic (per day)
growth_rate <- 0.1
## final day of observation; secondary events after this day are truncated
max_t <- 20
## number of line list entries to simulate delays for
## note we actually won't end up with this many samples as some will be
## truncated
samples <- 400
## parameters of the lognormal delay distribution we aim to recover
logmean <- 1.6
logsd <- 0.6

# Simulate the underlying outbreak structure assuming a gamma distributed
# generation time
## reproduction number
R <- 1.5
## squared coefficient of variation (variance / mean^2) of the generation
## time; the gamma shape parameter below is its inverse
tg_kappa <- 0.5
## mean generation time implied by R and growth_rate for a gamma generation
## time, obtained by solving R = (1 + tg_kappa * growth_rate * tg_mean)^(1 / tg_kappa)
tg_mean <- (exp(tg_kappa * log(R)) - 1) / (tg_kappa * growth_rate)
## gamma distribution parameters
tg_shape <- 1 / tg_kappa
tg_scale <- tg_mean / tg_shape
# Sampler for generation times, handed to bpmodels::chain_sim as `serial`.
# Draws `n` values from a gamma distribution whose shape and scale are read
# from the global `tg_shape` and `tg_scale` at call time.
tg <- function(n) rgamma(n = n, shape = tg_shape, scale = tg_scale)
# Simulate the outbreak as a branching process with Poisson offspring
# (mean R) and gamma distributed generation times, up to time max_t.
# The initial cases at time 0 are dropped and the remaining cases are given
# sequential ids.
linelist <- bpmodels::chain_sim(
  init_cases, "pois", tf = max_t, serial = tg, lambda = R
) |>
  data.table() |>
  DT(time > 0, list(time, id = seq_len(.N)))

# Aggregate case times to daily case counts (day = floor(time))
cases <- linelist |>
  DT(, .(time = as.integer(floor(time)))) |>
  DT(, .(cases = .N), by = time) |>
  setorder(time)

# Plot against the day rather than the row index so that days without any
# cases do not shift the x axis
plot(cases$time, cases$cases, xlab = "Day", ylab = "Cases")
# Simulate the observation process for the line list: sample case times
# without replacement and draw a lognormal delay to the secondary event
obs <- data.table(
  time = sample(linelist$time, samples, replace = FALSE),
  delay = rlnorm(samples, logmean, logsd)
) |>
  # Add a new ID
  DT(, id := seq_len(.N)) |>
  # Continuous time at which the secondary event (observation) happens
  DT(, obs_delay := time + delay) |>
  # Delay in whole days between the primary and secondary event days
  DT(, daily_delay := floor(obs_delay) - floor(time)) |>
  # Lower bound of the doubly censored delay (floored at 0)
  DT(, daily_delay_m1 := pmax(daily_delay - 1, 0)) |>
  # Upper bound of the doubly censored delay
  DT(, daily_delay_p1 := daily_delay + 1) |>
  # Days from the primary event day up to and including the last observed day
  DT(, obs_time := max_t + 1 - floor(time)) |>
  # The exact primary event time within its day is unknown, so censor it:
  # use the mid-day point as an average across the day
  DT(, censored_obs_time := obs_time - 0.5) |>
  DT(, censored := "interval")

# Make event based data for latent modelling: primary (ptime) and secondary
# (stime) event days plus the final observation day
obs <- obs |>
  DT(, ptime := floor(time)) |>
  DT(, stime := floor(obs_delay)) |>
  DT(, max_t := max_t)

# Right truncate: keep only secondary events observed by day max_t
truncated_obs <- obs |>
  DT(stime <= max_t)

# The lognormal family in brms does not support 0, so drop observations whose
# lower censoring bound (daily_delay_m1) is 0, i.e. keep delays of at least
# 2 days. The censoring-only model does not adjust for this extra truncation,
# so this could be improved.
double_truncated_obs <- truncated_obs |>
  DT(daily_delay_m1 >= 1)
# Positive-delay data for the models with `daily_delay` as the response.
# The lognormal family in brms does not support 0, and the truncation model
# uses a lower bound of 1, so same-day (0) delays must be dropped. Otherwise
# brms rejects the data whenever the simulation happens to produce one.
# (Filtering on censored_obs_time instead does not remove these rows.)
double_truncated_obs_trunc <- truncated_obs |>
  DT(daily_delay >= 1)

# Fit lognormal model with no corrections
naive_model <- brm(
  bf(daily_delay ~ 1, sigma ~ 1), data = double_truncated_obs_trunc,
  family = lognormal(), backend = "cmdstanr", adapt_delta = 0.9
)
# We see that the log mean is truncated
# the sigma_intercept needs to be exponentiated to return the log sd
summary(naive_model)

# Adjust for truncation: delays can only be observed up to the (mid-day
# censored) observation time of each primary event
trunc_model <- brm(
  bf(daily_delay | trunc(lb = 1, ub = censored_obs_time) ~ 1, sigma ~ 1),
  data = double_truncated_obs_trunc, family = lognormal(),
  backend = "cmdstanr", adapt_delta = 0.9
)
# Getting closer to recovering our simulated estimates
summary(trunc_model)

# Correct for censoring: the true delay lies in
# (daily_delay_m1, daily_delay_p1)
censor_model <- brm(
  bf(daily_delay_m1 | cens(censored, daily_delay_p1) ~ 1, sigma ~ 1),
  data = double_truncated_obs, family = lognormal(),
  backend = "cmdstanr", adapt_delta = 0.9
)
# Less close than truncation but better than naive model
summary(censor_model)

# Correct for double interval censoring and truncation
censor_trunc_model <- brm(
  bf(
    daily_delay_m1 | trunc(lb = 1, ub = censored_obs_time) +
      cens(censored, daily_delay_p1) ~ 1,
    sigma ~ 1
  ),
  data = double_truncated_obs, family = lognormal(),
  backend = "cmdstanr", adapt_delta = 0.9
)
# Recover underlying distribution
# As the growth rate increases and with short delays we may still see a bias
# as we have a censored observation time
summary(censor_trunc_model)
# Model censoring as a latent process (WIP)
# For this model we need to use a custom brms family and so
# the code is significantly more complex.

# Fit a lognormal delay model with latent primary/secondary event times.
#
# Defines a custom brms family, `latent_lognormal`, in which the primary and
# secondary event times are the observed event days plus latent within-day
# offsets (`pwindow`, `swindow`, each bounded to [0, 1] with uniform priors).
# The likelihood is the lognormal density of the latent delay, normalised by
# the probability that the delay is short enough to be observed by the end of
# the observation window (right truncation).
#
# Arguments:
#   fn  - model fitting function, `brm` by default. `fn` is the FIRST formal
#         argument, so an unnamed first argument (e.g. the formula) is matched
#         to it. Always pass `fn` by name, or pass the formula as `formula =`.
#   ... - passed on to `fn`: the model formula, whose response is the primary
#         event day with `vreal(secondary event day, last observed day)`, plus
#         data and sampler options.
# Returns the fitted model returned by `fn`.
fit_latent_lognormal <- function(fn = brm, ...) {
  latent_lognormal <- custom_family(
    "latent_lognormal",
    dpars = c("mu", "sigma", "pwindow", "swindow"),
    links = c("identity", "log", "identity", "identity"),
    lb = c(NA, 0, 0, 0),
    ub = c(NA, NA, 1, 1),
    type = "real",
    # extra data: secondary event day and last observed day
    vars = c("vreal1[n]", "vreal2[n]")
  )
  stan_funs <- "
  real latent_lognormal_lpdf(real y, real mu, real sigma, real pwindow,
                             real swindow, real sevent,
                             real end_t) {
    // latent primary and secondary event times within their observed days
    real p = y + pwindow;
    real s = sevent + swindow;
    real d = s - p;
    // Secondary events are observed up to and including day end_t, i.e. if
    // they occur before end_t + 1, so this is the longest observable delay
    real obs_time = end_t + 1 - p;
    return lognormal_lpdf(d | mu, sigma) - lognormal_lcdf(obs_time | mu, sigma);
  }
  "
  # NOTE(review): when both events fall on the same day and swindow < pwindow,
  # d is negative (outside the lognormal support) - presumably a source of the
  # instability noted where this model is fitted; TODO confirm.
  stanvars <- stanvar(block = "functions", scode = stan_funs)
  # Set up shared priors ----------------------------------------------------
  priors <- c(
    prior(uniform(0, 1), class = "b", dpar = "pwindow", lb = 0, ub = 1),
    prior(uniform(0, 1), class = "b", dpar = "swindow", lb = 0, ub = 1)
  )
  fn(family = latent_lognormal, stanvars = stanvars, prior = priors, ...)
}
# Fit latent lognormal model
# The response is the primary event day (`ptime`); `vreal()` passes the
# secondary event day (`stime`) and the final observation day (`max_t`) to the
# custom family. These are the column names created in `truncated_obs`. Each
# observation gets its own latent within-day offsets via `as.factor(id)`.
latent_model <- fit_latent_lognormal(
  bf(ptime | vreal(stime, max_t) ~ 1, sigma ~ 1,
     pwindow ~ 0 + as.factor(id), swindow ~ 0 + as.factor(id)),
  data = truncated_obs, backend = "cmdstanr", fn = brm,
  adapt_delta = 0.95
)
# Should also see parameter recovery using this method though
# run-times are much higher and the model is somewhat unstable.
summary(latent_model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment