Skip to content

Instantly share code, notes, and snippets.

@boennecd
Created January 22, 2020 08:39
Show Gist options
  • Select an option

  • Save boennecd/7bfd5a3715edc6b5fedf6bab6bfeaf0e to your computer and use it in GitHub Desktop.

Select an option

Save boennecd/7bfd5a3715edc6b5fedf6bab6bfeaf0e to your computer and use it in GitHub Desktop.
# probit-issue-simulate.R
library(Rcpp)
# 'Weibull_mix.cpp' is at
# https://gist.github.com/boennecd/770cd7db6f0fa1b6829311ef005436b7
Rcpp::sourceCpp("Weibull_mix.cpp")
# Simulate potentially right censored outcomes from a generalized survival
# model with a given link function and a mixed Weibull distribution as the
# baseline.
#
# The baseline survival function is 'Surv_Weibull_mix' which must be in scope
# when this function is called (it is expected to come from the C++ file
# sourced at the top of this script).
#
# Args:
#   b: numeric vector with coefficients for the fixed effects; one per column
#     of X.
#   X: numeric design matrix for the fixed effects.
#   Z: numeric design matrix for the random effects with one row per row of X.
#   cvmat: covariance matrix for the random effects with NCOL(Z) rows and
#     columns.
#   grp: integer or factor vector with group identifiers; one per row of X.
#     The identifiers need not be 1, ..., n_grp: they are mapped to their
#     rank among the sorted unique identifiers (factors are ordered by their
#     level codes).
#   ps: mixture probabilities. They are normalized to sum to one.
#   gs: mixture shape parameters.
#   ls: mixture scale parameters.
#   cens_func: function to simulate censoring times given an integer with the
#     number of observations. It must return that many strictly positive
#     values.
#   link: character specifying the link function; "PH", "PO", or "probit".
#     Only the first element is used.
#
# Returns:
#   A list with
#     y: the observed times, i.e. the minimum of the simulated event time and
#       the censoring time.
#     event: logical vector which is TRUE when the event time is smaller than
#       the censoring time.
#     rngs: matrix with the drawn random effects. Row i holds the effects of
#       the group with the i'th smallest identifier.
sim_wei_mix_base <- function(
  b, X, Z, cvmat, grp, ps = c(.5, .5), gs = c(1.5, .5), ls = c(1, 1),
  cens_func = function(n_obs) runif(n_obs, 1, 5),
  link = c("PH", "PO", "probit")) {
  # only 'vuniroot' is needed so the package is not attached to the search
  # path as a side effect of calling this function
  if (!requireNamespace("rstpm2", quietly = TRUE))
    stop(sprintf("Requires %s", sQuote("rstpm2")))

  n_obs <- NROW(X)
  n_fix <- NCOL(X)
  n_rng <- NCOL(Z)
  n_mix <- length(gs)
  link <- link[1]

  # check args
  stopifnot(
    length(ps) == n_mix, all(ps > 0),
    all(gs > 0),
    length(ls) == n_mix, all(ls > 0),
    length(b) == n_fix, is.numeric(b),
    all(dim(as.matrix(cvmat)) == c(n_rng, n_rng)),
    length(grp) == n_obs, is.integer(grp) || is.factor(grp), !anyNA(grp),
    NROW(Z) == n_obs, is.numeric(Z),
    is.numeric(X),
    is.function(cens_func),
    is.character(link), link %in% c("PH", "PO", "probit"))

  # factorize after the dimension checks such that a malformed 'cvmat' is
  # reported by the checks above
  chol_cv <- chol(cvmat)

  # make transformations and assign variables used later. The group
  # identifiers are re-coded to 1, ..., n_grp as they are used as row indices
  # into the matrix with random effects. Identifiers that already are
  # 1, ..., n_grp are unchanged by this
  if (is.factor(grp))
    grp <- as.integer(grp)
  grp_ids <- sort(unique(grp))
  n_grp <- length(grp_ids)
  grp <- match(grp, grp_ids)

  # assign survival function
  ps <- ps / sum(ps)
  S0 <- Surv_Weibull_mix
  formals(S0)[c("ps", "gs", "ls")] <- list(ps, gs, ls)

  # simulate linear predictor. 'drop = FALSE' keeps the matrix shape also
  # when there is only one random effect or one observation
  rngs <- matrix(rnorm(n_grp * n_rng), ncol = n_rng) %*% chol_cv
  colnames(rngs) <- colnames(Z)
  lp <- drop(X %*% b + rowSums(Z * rngs[grp, , drop = FALSE]))

  # simulate survival time and censoring time. The survival time is the root
  # in 'ti' of the function below where 'u' is a standard uniform draw
  # TODO: can we just simulate from a mixture?
  unifs <- runif(n_obs, 0, 1)
  root_func <- switch(
    link,
    PH = function(ti, u, lp) S0(ti) - u^exp(-lp),
    PO = function(ti, u, lp) {
      sb <- S0(ti)
      f <- exp(lp / 2)
      # the epsilon avoids a division by zero when the baseline survival
      # probability is zero
      (1 - sb) / (sb + .Machine$double.eps) * f - (1 - u) / u / f
    },
    probit = function(ti, u, lp) {
      sb <- S0(ti)
      # both terms are truncated to keep them finite when 'qnorm' returns
      # -Inf or Inf
      eta_min <- qnorm(.Machine$double.eps)
      clamp <- function(eta)
        pmax(eta_min, pmin(-eta_min, eta))
      lp_2 <- lp / 2
      clamp(-qnorm(sb) + lp_2) + clamp(lp_2 + qnorm(u))
    })

  # NOTE(review): all warnings from the root finder are suppressed as in the
  # original code; presumably they stem from evaluating the function at the
  # extreme end points -- confirm before narrowing this
  y <- suppressWarnings(rstpm2::vuniroot(
    root_func, lower = rep(0, n_obs),
    upper = rep(.Machine$double.xmax^(1 / 16), n_obs),
    tol = .Machine$double.eps^(3 / 4), u = unifs, lp = lp))
  y <- y$root
  # bound the survival times away from zero
  y <- pmax(.Machine$double.eps^(1 / 2), y)

  cens <- cens_func(n_obs)
  stopifnot(length(cens) == n_obs, all(cens > 0))

  event <- y < cens
  y <- pmin(y, cens)

  list(y = y, event = event, rngs = rngs)
}
# test the function. The test is wrapped in an anonymous function which is
# called immediately such that no variables are left in the calling
# environment
if (TRUE)
  (function() {
    # reset seed after test
    if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE))
      runif(1)
    old_seed <- .GlobalEnv$.Random.seed
    on.exit(.GlobalEnv$.Random.seed <- old_seed)

    #####
    # simulate data
    #
    # Args:
    #   link: character with the link function which is passed on to
    #     'sim_wei_mix_base'.
    #
    # Returns:
    #   The list returned by 'sim_wei_mix_base' for a small data set with
    #   five groups with two observations each.
    get_dat <- function(link) {
      dat <- within(list(), {
        n_p_grp <- 2L
        n_grp <- 5L
        n_obs <- n_p_grp * n_grp
        grp <- as.integer(gl(n_grp, n_p_grp))
        X <- cbind(`(Intercept)` = 1, treatment = runif(n_obs) > .5,
                   x = rnorm(n_grp)[grp])
        Z <- cbind(`(Intercept)` = 1, x = X[, "x"])
        beta <- c(.5, 1/4, 1/4)
        stds <- c(.5, 1/4)
        rho <- .5 * prod(stds)
        cvmat <- matrix(c(stds[1]^2, rho, rho, stds[2]^2), ncol = 2)
      })
      # 'link' has to be passed on. Otherwise, the default link ("PH") is
      # used no matter what the caller asks for
      with(dat, sim_wei_mix_base(
        b = beta, X = X, Z = Z, cvmat = cvmat, grp = grp, link = link))
    }

    # the random effects are drawn before the link function is used. Thus,
    # they are identical for all links given the same seed
    expected_rngs <- structure(
      c(0.755890584225424, 0.194921618205716, -0.310620290270902,
        -1.10734994358875, 0.562465459071554, 0.179244234333629,
        0.0452251097670186, 0.126691461298552, -0.0990378816268971,
        0.269199772646765),
      .Dim = c(5L, 2L), .Dimnames = list(NULL, c("(Intercept)", "x")))

    # regression test with previously computed values for the "PH" link
    set.seed(1)
    ph_dat <- get_dat("PH")
    stopifnot(isTRUE(all.equal(
      ph_dat,
      list(
        y = c(0.0261795159458619, 0.113360951088831, 0.0277792241363675,
              0.138872464541846, 0.369658829436196, 0.0531238909295669,
              2.265086828731, 0.969466868448416, 0.0391259321937663,
              0.0843747507945069),
        event = c(TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, FALSE, TRUE, TRUE,
                  TRUE),
        rngs = expected_rngs))))

    # NOTE(review): the hard-coded values previously used for the "PO" and
    # "probit" links were identical to the "PH" values because 'link' was
    # never passed on. They are replaced by checks that hold by construction
    # and a check that the link changes the simulated times. Add regression
    # values for these links once they have been recomputed
    for (link in c("PO", "probit")) {
      set.seed(1)
      sim_dat <- get_dat(link)
      stopifnot(
        identical(names(sim_dat), c("y", "event", "rngs")),
        length(sim_dat$y) == 10L, all(sim_dat$y > 0),
        is.logical(sim_dat$event), length(sim_dat$event) == 10L,
        isTRUE(all.equal(sim_dat$rngs, expected_rngs)),
        !isTRUE(all.equal(sim_dat$y, ph_dat$y)))
    }
  })()
invisible()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment