Skip to content

Instantly share code, notes, and snippets.

Last active June 3, 2024 05:03
Show Gist options
  • Save apoorvalal/64b05dbe0a1d4f72ebafb491bd282ffe to your computer and use it in GitHub Desktop.
Save apoorvalal/64b05dbe0a1d4f72ebafb491bd282ffe to your computer and use it in GitHub Desktop.
Inference for the Population Average Treatment Effect using fully interacted OLS
library(momentfit); library(car); library(tictoc)
# %%
#' Simulate one sample from a two-arm randomized experiment.
#'
#' Covariates are iid standard normal. Each potential outcome is a nonlinear
#' function of one covariate: Y1 of X1 (plus U(-0.5, 0.5) noise) and Y0 of X2
#' (plus U(-1, 1) noise). E[X + X^2] = 1 for a standard normal X, so both
#' potential outcomes have mean 1 and the true average treatment effect is 0.
#' Treatment is assigned as Bernoulli(0.6).
#'
#' @param n Number of units.
#' @param k Number of covariates (must be >= 2). Only X1 and X2 enter the
#'   outcomes; any further columns are pure noise covariates.
#' @return A data.frame with columns Y, Z, X1, ..., Xk.
dgp <- \(n = 500, k = 2) {
  stopifnot(k >= 2)
  X <- matrix(rnorm(n * k), n, k)
  Y1 <- X[, 1] + X[, 1]^2 + runif(n, -0.5, 0.5)
  Y0 <- X[, 2] + X[, 2]^2 + runif(n, -1, 1)
  Z <- rbinom(n, 1, 0.6)
  Y <- Z * Y1 + (1 - Z) * Y0  # observed outcome
  data.frame(Y, Z, X)
}
# One simulated sample (default size) for the single-run comparison.
df <- dgp()
# %%
#' Fully interacted OLS estimate of the average treatment effect.
#'
#' The covariates are standardised with scale(), so they are centred. Y is
#' then regressed on Z, X and Z:X, and the coefficient on Z estimates the
#' ATE. Two standard errors are reported:
#'   * "EHW SE": heteroskedasticity-robust, from car::hccm() at its default
#'     type (HC3 in current car releases).
#'   * "SuperPop SE": adds inter' Cov(X) inter / n to the robust variance.
#'     Here `inter` holds the treatment-covariate interaction coefficients.
#'     This term accounts for sampling the covariates from a superpopulation.
#'
#' @param Z Treatment indicator, length n.
#' @param Y Outcome, length n.
#' @param X Covariate matrix or data.frame with n rows.
#' @return Named numeric vector c("OLS Est", "EHW SE", "SuperPop SE").
linestimator <- function(Z, Y, X) {
  X <- scale(X)
  n <- nrow(X)
  p <- ncol(X)
  linreg <- lm(Y ~ Z * X)
  # coefficient order: (Intercept), Z, X[1..p], Z:X[1..p]
  est <- coef(linreg)[2]
  vehw <- hccm(linreg)[2, 2]
  ## super population correction
  inter <- coef(linreg)[(p + 3):(2 * p + 2)]
  vsuper <- vehw + sum(inter * (cov(X) %*% inter)) / n
  # setNames() directly: `%>%` is not provided by any package this file loads
  setNames(
    c(est, sqrt(vehw), sqrt(vsuper)),
    c("OLS Est", "EHW SE", "SuperPop SE")
  )
}
# %% estimate problem jointly via gmm
#' Just-identified moment conditions for jointly estimating the covariate
#' means and the fully interacted regression by GMM.
#'
#' theta = (mu_x, beta):
#'   * mu_x has length p.
#'   * beta has length 2p + 2, with coefficients ordered as
#'     (Z, intercept, X - mu_x, Z * (X - mu_x)).
#' So theta[p + 1] is the coefficient on Z, the treatment effect. The
#' covariates are centred at the *estimated* mean inside the same GMM
#' problem. This lets the GMM standard errors reflect sampling variability
#' in mu_x.
#'
#' @param theta Numeric vector of length 3p + 2.
#' @param df data.frame (or matrix) with columns (Y, Z, X1, ..., Xp).
#' @return n x (3p + 2) matrix of per-observation moments. The first p
#'   columns are the centring moments X - mu_x. The remaining 2p + 2 columns
#'   are the OLS score moments xtilde * (y - xtilde %*% beta).
moment_cond <- function(theta, df) {
  p <- ncol(df) - 2
  # unpack params: mu_x, beta
  mu_x <- theta[seq_len(p)]
  beta <- theta[(p + 1):length(theta)]
  # unpack data
  y <- as.numeric(df[, 1])
  z <- as.numeric(df[, 2])
  x <- as.matrix(df[, 3:ncol(df), drop = FALSE])
  xcent <- sweep(x, 2, mu_x)  # X - mu_x
  # Fully interacted design. z * xcent scales row i of xcent by z[i].
  # Using both x and xcent here would make the design collinear with the
  # intercept, and it would drop the treatment-covariate interactions.
  xtilde <- cbind(z, 1, xcent, z * xcent)
  resid <- as.vector(y - xtilde %*% beta)
  m <- xtilde * resid  # ols moment condition x(y - x'beta)
  cbind(xcent, m)  # stack moment conditions - [mu_x, beta]
}
# Single-sample comparison: fully interacted OLS vs. joint GMM.
n_cov <- ncol(df) - 2  # number of covariate columns after (Y, Z)
linestimator(df$Z, df$Y, df[, 3:ncol(df)])
# gmm: theta = (mu_x [n_cov], beta [2 * n_cov + 2]). The first beta entry,
# theta[n_cov + 1], is the coefficient on Z, i.e. the treatment effect.
m <- gel4(
  g = moment_cond, x = df, theta0 = runif(3 * n_cov + 2),
  coefSlv = "optim", gelType = "ET"
)
setNames(summary(m)@coef[n_cov + 1, 1:2], c("GMM Est", "GMM SE"))
# Output recorded from one run of the original gist (no seed is set, so
# values vary between runs):
# OLS Est -0.328529825912218, EHW SE 0.168275957745803,
#   SuperPop SE 0.177575093923765
# GMM Est -0.327120815478441, GMM SE 0.164029681609099
# %% repeat
# Monte Carlo over 5000 simulated samples, run on 12 cores. As indexed in the
# summaries, each column of `res` is one replication:
#   rows 1-3: OLS Est, EHW SE, SuperPop SE (from linestimator)
#   rows 4-5: GMM Est, GMM SE
res <- LalRUtils::mcReplicate(5000, {
  df <- dgp()
  n_cov <- ncol(df) - 2  # number of covariate columns after (Y, Z)
  linres <- linestimator(df$Z, df$Y, df[, 3:ncol(df)])
  m <- gel4(
    g = moment_cond, x = df, theta0 = runif(3 * n_cov + 2),
    coefSlv = "optim", gelType = "ET"
  )
  # coef and SE of the treatment effect, parameter n_cov + 1 in theta
  gmmres <- summary(m)@coef[n_cov + 1, 1:2]
  c(linres, gmmres)
}, mc.cores = 12)
# %% ## bias
# dgp() has a true ATE of 0, so the average estimate is the bias.
res[1, ] |> mean() # ols
# 0.00750274551064356
res[4, ] |> mean() # gmm
# 0.00436392500046487 # lower bias
# Share of replications whose interval est +/- crit * se contains 0.
# The interval contains 0 exactly when lower * upper <= 0.
ci_covers_zero <- function(est, se, crit = 1.96) {
  lower <- est - crit * se
  upper <- est + crit * se
  mean(lower * upper <= 0)
}
# %% coverage based on EHW standard error
ci_covers_zero(res[1, ], res[2, ])
# 0.9268 - under-coverage
# %% coverage based on population standard error
ci_covers_zero(res[1, ], res[3, ])
# 0.9548 - correct coverage but conservative
# %% coverage based on gmm standard error
ci_covers_zero(res[4, ], res[5, ])
# 0.952 - correct coverage, less conservative
# %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment