Created
May 6, 2020 17:48
-
-
Save herbps10/a7c03b1898a02c9b187a17bcc007d763 to your computer and use it in GitHub Desktop.
Optimal treatment effects with resource constraints
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
library(SuperLearner) | |
#
# Data generating process
#
# Simulate one data set of N independent observations.
#
# seed: value passed to set.seed() so each data set is reproducible.
# N:    number of rows to simulate.
#
# Returns a tibble with standard-normal covariates W1-W4, a fair-coin
# treatment A, a fair-coin indicator H, and a binary outcome Y whose
# success probability follows one of two logistic models selected by H.
generate_data <- function(seed, N) {
  set.seed(seed)

  # The draws below are made in the same order as the output columns so
  # that a given seed always produces the same random number stream.
  W1 <- rnorm(N)
  W2 <- rnorm(N)
  W3 <- rnorm(N)
  W4 <- rnorm(N)
  A <- rbinom(N, 1, 0.5)
  H <- rbinom(N, 1, 0.5)

  # Outcome probabilities under each value of H (deterministic given the
  # draws above, so computing both up front consumes no random numbers).
  prob_h0 <- boot::inv.logit(1 - W1^2 + 3 * W2 + A * (5 * W3^2 - 4.45))
  prob_h1 <- boot::inv.logit(
    -0.5 - W3 + 2 * W1 * W2 + A * (3 * abs(W2) - 1.5)
  )

  # ifelse() only evaluates a branch when at least one row needs it, so
  # the outcome draws stay inside the call.
  Y <- ifelse(H == 0, rbinom(N, 1, prob_h0), rbinom(N, 1, prob_h1))

  tibble(W1 = W1, W2 = W2, W3 = W3, W4 = W4, A = A, H = H, Y = Y)
}
#
# Plugin, one-step, and TMLE estimators
#
#' Plug-in, one-step, and TMLE estimates of the mean outcome under an
#' estimated treatment rule that treats at most a fraction `kappa`.
#'
#' @param data A data frame with a binary outcome `Y`, a binary treatment
#'   `A`, and covariates `W1`, `W2`, `W3`, `W4`. Treatment is taken to be
#'   assigned with known probability 0.5, as in `generate_data()`.
#' @param kappa Numeric scalar. The rule treats observations whose
#'   estimated benefit exceeds the `1 - kappa` quantile (and zero).
#' @param SL.library Character vector of learner names passed to
#'   `SuperLearner()` for both regressions.
#' @return A three-row tibble with columns `method` ("plugin", "onestep",
#'   "tmle"), `psi` (point estimate), and `ci` (list column holding a
#'   length-2 95% Wald interval; `NULL` for the plug-in estimator).
estimators <- function(data, kappa, SL.library) {
  # Known treatment probability used in every inverse-probability weight.
  g_a <- 0.5

  covariates <- select(data, W1, W2, W3, W4)

  # Prediction frame with treatment set to `a`, using the same column
  # order the outcome model was fitted with (A first, then W1-W4).
  with_treatment <- function(a) {
    covariates %>%
      mutate(A = a) %>%
      select(A, W1, W2, W3, W4)
  }

  # Outcome regression E[Y | A, W].
  mu_model <- SuperLearner(
    data$Y,
    select(data, A, W1, W2, W3, W4),
    SL.library = SL.library,
    family = "binomial"
  )
  mu_n_a_w <- predict(mu_model)$pred # predictions at the observed (A, W)
  mu_n_1 <- predict(mu_model, newdata = with_treatment(1))$pred
  mu_n_0 <- predict(mu_model, newdata = with_treatment(0))$pred

  # Inverse-probability-weighted transformed outcome, regressed on W to
  # rank observations by estimated benefit of treatment.
  # NOTE(review): with P(A = 1) = 0.5, the conditional mean of Y_tilde
  # given W is the treatment contrast plus mean(Y), so the `max(0, ...)`
  # threshold below is applied on a shifted scale - confirm intended.
  Y_tilde <- (2 * data$A - 1) / g_a * (data$Y - mean(data$Y)) +
    mean(data$Y)
  qbar_model <- SuperLearner(
    Y_tilde,
    covariates,
    SL.library = SL.library,
    family = "gaussian"
  )
  qbar_n <- predict(qbar_model)$pred

  # Treat only those above both zero and the (1 - kappa) quantile.
  tau_n <- max(0, quantile(qbar_n, 1 - kappa))
  d_n <- as.numeric(qbar_n > tau_n)

  #
  # Plug-in estimator
  #
  psi_plugin <- mean(predict(mu_model, newdata = with_treatment(d_n))$pred)

  #
  # One-step estimator
  #
  # Indicator that the observed treatment agrees with the estimated rule.
  follows_rule <- as.numeric(data$A == d_n)

  # estimate of influence function
  # NOTE(review): the second term uses the prediction at the observed
  # treatment (mu_n_a_w); the usual influence function for the value of a
  # rule uses the prediction under the rule, mu(d_n(W), W) - confirm.
  if_n <- follows_rule / g_a * (data$Y - mu_n_a_w) +
    mu_n_a_w - tau_n * (mean(d_n) - kappa) - psi_plugin

  # add the empirical mean of the influence function to the plugin estimator
  psi_onestep <- psi_plugin + mean(if_n)
  psi_onestep_ci <- psi_onestep +
    c(-1, 1) * qnorm(0.975) * sd(if_n) / sqrt(length(if_n))

  #
  # TMLE
  #
  # clever covariate
  H_a_w <- follows_rule / g_a
  H_0 <- as.numeric(0 == d_n) / g_a
  H_1 <- as.numeric(1 == d_n) / g_a

  # Fluctuate the initial fit along the clever covariate: intercept-free
  # logistic regression with the initial predictions as a fixed offset.
  logit_mu <- boot::logit(mu_n_a_w)
  fluctuation <- glm(
    data$Y ~ -1 + H_a_w + offset(logit_mu),
    family = binomial(link = "logit")
  )
  epsilon_n <- coef(fluctuation)[1]

  mu_n_a_w_epsilon <- boot::inv.logit(logit_mu + epsilon_n * H_a_w)
  mu_n_0_epsilon <- boot::inv.logit(boot::logit(mu_n_0) + epsilon_n * H_0)
  mu_n_1_epsilon <- boot::inv.logit(boot::logit(mu_n_1) + epsilon_n * H_1)

  psi_tmle <- mean(
    mu_n_0_epsilon * as.numeric(d_n == 0) +
      mu_n_1_epsilon * as.numeric(d_n == 1)
  )

  if_n_epsilon <- follows_rule / g_a * (data$Y - mu_n_a_w_epsilon) +
    mu_n_a_w_epsilon - tau_n * (mean(d_n) - kappa) - psi_tmle

  # The TMLE interval uses the influence function evaluated at the
  # updated fit, not the one-step influence function.
  psi_tmle_ci <- psi_tmle +
    c(-1, 1) * qnorm(0.975) * sd(if_n_epsilon) / sqrt(length(if_n_epsilon))

  tribble(
    ~method,   ~psi,        ~ci,
    "plugin",  psi_plugin,  NULL,
    "onestep", psi_onestep, psi_onestep_ci,
    "tmle",    psi_tmle,    psi_tmle_ci
  )
}
#
# Execute simulation study
#
# Candidate learners handed to SuperLearner() by estimators().
SL.library <- c(
  "SL.glm", "SL.glm.interaction", "SL.step", "SL.mean",
  "SL.step.interaction", "SL.step.forward"
)

# Number of replicate data sets per sample size.
simulations <- 250

# Reference value of the target parameter used for coverage and bias.
psi0 <- 0.49

# Resource constraint passed to estimators().
# NOTE(review): the original script never assigned `kappa`, so the call
# below picked up the base R function kappa() and failed on `1 - kappa`.
# 0.1 is an assumed value that presumably corresponds to psi0 = 0.49 -
# TODO confirm against the intended simulation design.
kappa <- 0.1

# Sample sizes to simulate.
Ns <- c(50, 500, 1000, 5000)

pb <- progress::progress_bar$new(total = simulations * length(Ns))

# One row per (seed, N) pair: simulate the data set, then attach the
# tibble of estimates returned by estimators() as a list column.
simulation_study <- expand_grid(
  seed = seq_len(simulations),
  N = Ns
) %>%
  mutate(
    data = map2(seed, N, generate_data),
    psi = map(data, function(data) {
      pb$tick()
      estimators(data, kappa, SL.library)
    })
  )
# Calculate coverage
# Does the interval `ci` (lower bound first, upper bound second) contain
# the value `x`? Bounds are inclusive. A NULL interval yields NA because
# comparing against NULL produces a zero-length logical.
covered <- function(ci, x) {
  lower <- ci[1]
  upper <- ci[2]
  (lower <= x) && (upper >= x)
}
# Share of replicates whose confidence interval contains psi0, for each
# sample size and method. Rows without an interval (the plug-in
# estimator) evaluate to NA and are dropped before averaging.
simulation_study_coverage <- simulation_study %>%
  select(-data) %>%
  unnest(psi) %>%
  mutate(is_covered = map_lgl(ci, covered, psi0)) %>%
  filter(!is.na(is_covered)) %>%
  group_by(N, method) %>%
  summarize(coverage = mean(is_covered))
# Coverage Plot (Figure 1)
# Empirical coverage by sample size, one colour per method, with a
# dashed reference line at the nominal 95% level.
ggplot(
  data = simulation_study_coverage,
  mapping = aes(x = factor(N), y = coverage, color = method)
) +
  geom_point(size = 2) +
  geom_hline(yintercept = 0.95, linetype = 2) +
  labs(x = "N", y = "95% CI coverage") +
  cowplot::theme_cowplot()
# Mean absolute bias and standard errors (Table 1)
# For each sample size and method: the mean absolute deviation of the
# estimates from psi0 and the standard deviation of the estimates across
# replicates, both rounded to two significant digits, as a LaTeX table.
simulation_study %>%
  select(-data) %>%
  unnest(psi) %>%
  group_by(N, method) %>%
  summarize(
    abs_bias = mean(abs(psi0 - psi)),
    se = sd(psi)
  ) %>%
  mutate(
    abs_bias = signif(abs_bias, 2),
    se = signif(se, 2)
  ) %>%
  knitr::kable(format = "latex")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment