# Discrete missing values in Stan when the number of categories > 2
# GitHub gist eveskew/cbce607e252638f5ebf082c634fa814d, created December 4, 2018 20:12
# NOTE: GitHub reports that this file contains bidirectional Unicode text that
# may be interpreted or compiled differently than it appears; review it in an
# editor that reveals hidden Unicode characters.
# Categorical missing data in Stan
# Demonstrated with three categories
# Code and ideas adapted from @rmcelreath
# https://gist.github.com/rmcelreath/9406643583a8c99304e459e644762f82
# Code and ideas discussed with @dmontecino
# https://gist.github.com/dmontecino/b804853e4b36a57990a7108a35201cf5

# rstan compiles and samples the Stan model; rethinking supplies the
# simplehist(), precis() and logistic() helpers used below.
library(rstan)
library(rethinking) # devtools::install_github("rmcelreath/rethinking")
#==============================================================================
# Data setup

# Note: in this simulated data, categories are unbalanced and the data is
# missing completely at random (the missing rows are sampled uniformly)
set.seed(20181204) # make the simulated data (and default Stan seed) reproducible

N <- 5000
N_missing <- 100
K <- 3 # number of categories

# Generate unordered categorical covariates, with unbalanced observations
# Category 1 = 50% of simulated data
# Category 2 = 30% of simulated data
# Category 3 = 20% of simulated data
p_x <- c(0.5, 0.3, 0.2)

# Simulate binary response as a function of the category:
# Pr(y = 1 | x = k) for categories 1, 2 and 3
p_y_given_x <- c(0.8, 0.3, 0.1)

# Pr(cov_for_x = 1 | x = k). With equal values, as here, the covariate does
# not actually correlate with the category, but changing these values makes
# it informative -- including for the rows where x ends up missing.
p_cov_given_x <- c(0.5, 0.5, 0.5)

stopifnot(
  length(p_x) == K,
  length(p_y_given_x) == K,
  length(p_cov_given_x) == K
)

x <- sample.int(K, size = N, replace = TRUE, prob = p_x)
y <- rbinom(N, size = 1, prob = p_y_given_x[x])

# Simulate missing data
i_miss <- sample.int(N, size = N_missing)
x_obs <- x
x_obs[i_miss] <- -1L # placeholder, Stan will not accept NA values
x_miss <- as.integer(seq_len(N) %in% i_miss)

# Create a covariate for use in prediction of the categorical variable.
# It is drawn from the TRUE category x (not x_obs), so the missing rows follow
# the same covariate/category relationship the model assumes for all rows.
cov_for_x <- rbinom(N, size = 1, prob = p_cov_given_x[x])

# Create dummy variables for use in the model (all zero where x is missing)
x_cat_1 <- as.integer(x_obs == 1)
x_cat_2 <- as.integer(x_obs == 2)
x_cat_3 <- as.integer(x_obs == 3)
# Covariate values for the rows whose category is missing
cov_for_x_for_miss_cat <- cov_for_x[x_miss == 1]
# Examine distribution of full simulated data
simplehist(x)
# Examine distribution of simulated data that goes unobserved
simplehist(x[x_miss == 1])
# So in this case, the categories are unbalanced, but the unobserved data
# reflects this same pattern (the missing rows were sampled uniformly at random)
#==============================================================================
# Model definition

# Stan program for a binary outcome y whose predictor is a categorical variable
# x with K = 3 levels, where x is missing for some observations.
# - Where x is observed, x gets its own categorical likelihood (a
#   multinomial-logit model with cov_for_x as predictor and category 1 as the
#   reference), and y is modeled with the coefficient for that category.
# - Where x is missing, the joint likelihood Pr(x = k, y) is summed over the
#   three categories on the log scale (log_sum_exp), marginalizing x out.
# The generated quantities block returns x_imp, an N x K matrix. For missing x,
# row i holds Pr(x_i = k | y_i); for observed x, it is a one-hot indicator.
#
# The program indexes score[1..3] and beta1..beta3 directly, so the data block
# constrains K to exactly 3. The integer data are also range-checked.
# NOTE(review): uses the pre-2.26 array syntax (`int y[N];`), which Stan 2.33+
# no longer accepts -- confirm against the installed Stan/rstan version.
model <- "
data {
  int<lower=0> N;                      // number of observations
  int<lower=3, upper=3> K;             // number of categories (model is written for exactly 3)
  int<lower=0, upper=1> y[N];          // binary outcome variable
  int<lower=-1, upper=K> x_obs[N];     // observed categorical variable (-1 when unobserved)
  int<lower=0, upper=1> x_miss[N];     // 1 when the categorical variable is missing
  int<lower=0, upper=1> cov_for_x[N];  // a binary predictor for the x categorical variable
  int<lower=0, upper=1> x_cat_1[N];    // 1 when x is observed at the first level
  int<lower=0, upper=1> x_cat_2[N];    // 1 when x is observed at the second level
  int<lower=0, upper=1> x_cat_3[N];    // 1 when x is observed at the third level
}
///////////////////////////////////////////////////////////////////////////////
parameters {
  real beta1;   // log-odds of y = 1 for category 1
  real beta2;   // log-odds of y = 1 for category 2
  real beta3;   // log-odds of y = 1 for category 3
  real a_cat2;  // intercept of the log-odds of category 2 (vs category 1)
  real a_cat3;  // intercept of the log-odds of category 3 (vs category 1)
  real b_cat2;  // effect of cov_for_x on the log-odds of category 2
  real b_cat3;  // effect of cov_for_x on the log-odds of category 3
}
///////////////////////////////////////////////////////////////////////////////
model {
  // priors for all parameters
  beta1 ~ normal(0, 1);
  beta2 ~ normal(0, 1);
  beta3 ~ normal(0, 1);
  a_cat2 ~ normal(0, 1);
  a_cat3 ~ normal(0, 1);
  b_cat2 ~ normal(0, 1);
  b_cat3 ~ normal(0, 1);
  // likelihood
  for (i in 1:N) {
    vector[K] score;  // log-odds of categories 1:K relative to category 1
    score[1] = 0;     // category 1 is the reference, so its score is fixed at zero
    score[2] = a_cat2 + b_cat2 * cov_for_x[i];
    score[3] = a_cat3 + b_cat3 * cov_for_x[i];
    if (x_miss[i] == 1) {  // x is unobserved
      // marginalize over the unknown category:
      // log Pr(y) = log_sum_exp over k of [log Pr(x = k) + log Pr(y | x = k)]
      vector[K] log_p_x;  // log Pr(x = k) for k in 1:K
      vector[K] logPxy;   // log Pr(x = k, y) for k in 1:K
      log_p_x = log_softmax(score);
      logPxy[1] = log_p_x[1] + bernoulli_logit_lpmf(y[i] | beta1);
      logPxy[2] = log_p_x[2] + bernoulli_logit_lpmf(y[i] | beta2);
      logPxy[3] = log_p_x[3] + bernoulli_logit_lpmf(y[i] | beta3);
      target += log_sum_exp(logPxy);
    }
    else {  // x is observed
      x_obs[i] ~ categorical_logit(score);  // same as categorical(softmax(score))
      y[i] ~ bernoulli_logit(beta1 * x_cat_1[i] + beta2 * x_cat_2[i] + beta3 * x_cat_3[i]);
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
generated quantities {
  matrix[N, K] x_imp;  // row i: probability that observation i belongs to each category 1:K
  for (i in 1:N) {
    if (x_miss[i] == 1) {  // x is unobserved
      // Pr(k | y) = Pr(k, y) / Pr(y), with Pr(k, y) = Pr(k) Pr(y | k) and
      // Pr(y) = sum over k of Pr(k, y); softmax of the log joint is that ratio
      vector[K] score;    // log-odds of categories 1:K relative to category 1
      vector[K] log_p_x;  // log Pr(x = k) for k in 1:K
      vector[K] logPxy;   // log Pr(x = k, y) for k in 1:K
      score[1] = 0;
      score[2] = a_cat2 + b_cat2 * cov_for_x[i];
      score[3] = a_cat3 + b_cat3 * cov_for_x[i];
      log_p_x = log_softmax(score);
      logPxy[1] = log_p_x[1] + bernoulli_logit_lpmf(y[i] | beta1);
      logPxy[2] = log_p_x[2] + bernoulli_logit_lpmf(y[i] | beta2);
      logPxy[3] = log_p_x[3] + bernoulli_logit_lpmf(y[i] | beta3);
      x_imp[i] = to_row_vector(softmax(logPxy));
    }
    else {  // x is observed, so its category is known with certainty
      x_imp[i] = rep_row_vector(0, K);
      x_imp[i, x_obs[i]] = 1;
    }
  }
}
"
#==============================================================================
# Fit the data

# Pass every variable declared in the Stan data block explicitly (including
# the x_cat_* dummy variables) rather than relying on rstan to find any of
# them outside this list.
fit_model <-
  stan(model_code = model,
       data = list(N = N, K = K, y = y, x_obs = x_obs, x_miss = x_miss,
                   cov_for_x = cov_for_x,
                   x_cat_1 = x_cat_1, x_cat_2 = x_cat_2, x_cat_3 = x_cat_3),
       iter = 1000, chains = 4, cores = 4,
       control = list(adapt_delta = 0.995, max_treedepth = 15))
# Examine the model output
precis(fit_model)
out <- extract(fit_model)

# The model is correctly recovering the relationship between category
# and outcome. Each posterior draw is moved to the probability scale before
# averaging, so these are posterior means of Pr(y = 1 | x = k).
mean(logistic(out$beta1)) # should be ~0.8
mean(logistic(out$beta2)) # should be ~0.3
mean(logistic(out$beta3)) # should be ~0.1

# And is correctly recovering the relative probability of belonging to a
# category. The category probabilities depend on the binary covariate
# cov_for_x, so average them over its observed distribution. Evaluating them
# only at cov_for_x = 0 would ignore b_cat2 and b_cat3.
softmax_rows <- function(scores) exp(scores) / rowSums(exp(scores))
p_cov_1 <- mean(cov_for_x) # observed share of rows with cov_for_x == 1
p_cat_cov_0 <- softmax_rows(cbind(0, out$a_cat2, out$a_cat3))
p_cat_cov_1 <- softmax_rows(
  cbind(0, out$a_cat2 + out$b_cat2, out$a_cat3 + out$b_cat3)
)
colMeans((1 - p_cov_1) * p_cat_cov_0 + p_cov_1 * p_cat_cov_1)
# should be ~0.5, ~0.3, ~0.2

# Note: here x is missing completely at random, so the observed x values are a
# representative sample of the true x distribution. If missingness skewed the
# observed x distribution away from the true one (e.g. if it depended on the
# category itself), these category probabilities would likely not be recovered.
# This model is also generating estimates for the probability of missing
# values belonging to different x categories

# Posterior mean probability of each category for every observation.
# out$x_imp is a (draws x N x K) array, so averaging over the first dimension
# gives an N x K matrix with one column per category.
post.prob.means <- colMeans(out$x_imp)
colnames(post.prob.means) <- paste0("p_cat", seq_len(ncol(post.prob.means)))

# Show this matrix with only those rows representing missing x data
# (drop = FALSE keeps a matrix even when only one row is missing)
post.prob.means[x_miss == 1, , drop = FALSE]
# And show this data along with the relevant y observations
cbind(post.prob.means[x_miss == 1, , drop = FALSE], y = y[x_miss == 1])
# this shows that the probabilities of belonging to a category
# shift depending upon the observed data
# Sign up for free or sign in to GitHub to join the conversation and comment on this gist.