Skip to content

Instantly share code, notes, and snippets.

@eveskew
Created December 4, 2018 20:12
Show Gist options
  • Save eveskew/cbce607e252638f5ebf082c634fa814d to your computer and use it in GitHub Desktop.
Save eveskew/cbce607e252638f5ebf082c634fa814d to your computer and use it in GitHub Desktop.
Discrete missing values in Stan when the number of categories > 2
# Categorical missing data in Stan
# Demonstrated with three categories
# Code and ideas adapted from @rmcelreath
# https://gist.github.com/rmcelreath/9406643583a8c99304e459e644762f82
# Code and ideas discussed with @dmontecino
# https://gist.github.com/dmontecino/b804853e4b36a57990a7108a35201cf5
library(rstan)
library(rethinking) # devtools::install_github("rmcelreath/rethinking")
#==============================================================================
# Data setup
# Note: in this simulated data, categories are unbalanced and the data is
# missing completely at random (the missing rows are drawn uniformly, without
# reference to the category or the outcome)
N <- 5000
N_missing <- 100
K <- 3 # number of categories

# Generate unordered categorical covariates, with unbalanced observations
# Category 1 = 50% of simulated data
# Category 2 = 30% of simulated data
# Category 3 = 20% of simulated data
# rmultinom() returns a K x N one-hot matrix (one column per observation);
# each column is collapsed to the index of the category that was drawn
x_onehot <- rmultinom(N, size = 1, prob = c(0.5, 0.3, 0.2))
x <- vapply(
  seq_len(N),
  function(i) which(x_onehot[, i] == 1),
  integer(1)
)

# Simulate binary response as a function of the category
# Pr(y = 1) is 0.8, 0.3, and 0.1 for categories 1, 2, and 3, respectively;
# indexing the probability vector by x gives one probability per observation
p_y_by_cat <- c(0.8, 0.3, 0.1)
y <- rbinom(N, size = 1, prob = p_y_by_cat[x])

# Simulate missing data
i_miss <- sample(seq_len(N), size = N_missing) # rows whose x goes unobserved
x_obs <- x
x_obs[i_miss] <- -1L # placeholder, Stan will not accept NA values
x_miss <- as.integer(seq_len(N) %in% i_miss) # 1 = x unobserved, 0 = observed
# Create a covariate for use in prediction of the categorical variable
# Note: in this case, the covariate does not actually correlate with the
# category, but the code is such that this could be easily changed (give the
# categories different probabilities in cov_prob_by_cat)
cov_prob_by_cat <- rep(0.5, K)

# Preallocate to full length N so that every observation, including trailing
# ones with a missing category, is guaranteed to receive a value
cov_for_x <- rep(NA_integer_, N)
for (k in seq_len(K)) {
  in_cat <- x_obs == k
  cov_for_x[in_cat] <- rbinom(sum(in_cat), size = 1, prob = cov_prob_by_cat[k])
}

# Observations with an unobserved category (x_obs == -1) are still NA here
unfilled <- is.na(cov_for_x)
cov_for_x[unfilled] <- rbinom(sum(unfilled), size = 1, prob = 0.5)

# Create dummy variables for use in the model
# (all three are 0 for observations whose category is unobserved)
x_cat_1 <- as.integer(x_obs == 1)
x_cat_2 <- as.integer(x_obs == 2)
x_cat_3 <- as.integer(x_obs == 3)
cov_for_x_for_miss_cat <- cov_for_x[x_miss == 1]

# Examine distribution of full simulated data
simplehist(x)

# Examine distribution of simulated data that goes unobserved
simplehist(x[x_miss == 1])
# So in this case, the categories are unbalanced, but the unobserved data
# reflects this same pattern
#==============================================================================
# Model definition
# Stan program, held as an R string, that:
#   * models the binary outcome y with one logit-scale coefficient per x
#     category (beta1, beta2, beta3)
#   * models x itself with a categorical (softmax) regression on cov_for_x,
#     using category 1 as the reference level (its score is fixed at zero)
#   * marginalizes over the possible categories whenever x is unobserved
#   * reports in generated quantities, for every observation, the probability
#     of belonging to each category
# Note: the program is written for exactly K = 3 categories
model <- "
data {
  int N; // number of observations
  int K; // number of categories
  int<lower=0, upper=1> y[N]; // binary outcome variable
  int x_obs[N]; // observed categorical variable (-1 when unobserved)
  int<lower=0, upper=1> x_miss[N]; // dummy variable indicating missingness in the categorical variable
  int cov_for_x[N]; // a binary predictor for the x categorical variable
  int x_cat_1[N]; // dummy variable indicating when x has the first level
  int x_cat_2[N]; // dummy variable indicating when x has the second level
  int x_cat_3[N]; // dummy variable indicating when x has the third level
}
///////////////////////////////////////////////////////////////////////////////
parameters {
  real beta1; // coefficient for categorical variable level 1
  real beta2; // coefficient for categorical variable level 2
  real beta3; // coefficient for categorical variable level 3
  real a_cat2; // intercept to model the probability of category 2
  real a_cat3; // intercept to model the probability of category 3
  real b_cat2; // coefficient to model the probability of category 2
  real b_cat3; // coefficient to model the probability of category 3
}
///////////////////////////////////////////////////////////////////////////////
model {
  // priors for all parameters
  beta1 ~ normal(0, 1);
  beta2 ~ normal(0, 1);
  beta3 ~ normal(0, 1);
  a_cat2 ~ normal(0, 1);
  a_cat3 ~ normal(0, 1);
  b_cat2 ~ normal(0, 1);
  b_cat3 ~ normal(0, 1);

  // likelihood
  for (i in 1:N) {
    vector[K] score; // vector containing log-odds of being in category 1, 2, or 3
    score[1] = 0; // score[1] is fixed at zero
    score[2] = a_cat2 + b_cat2*cov_for_x[i]; // modeling the log-odds of category 2
    score[3] = a_cat3 + b_cat3*cov_for_x[i]; // modeling the log-odds of category 3

    if (x_miss[i] == 1) { // x is unobserved
      // model the binary outcome, marginalizing over missingness
      vector[K] log_p_x; // log probability of each x category
      vector[K] logPxy; // log probabilities for each alternate scenario (category 1, 2, or 3)
      log_p_x = log_softmax(score); // computed once and reused for all categories
      logPxy[1] = log_p_x[1] + bernoulli_logit_lpmf(y[i] | beta1); // category 1
      logPxy[2] = log_p_x[2] + bernoulli_logit_lpmf(y[i] | beta2); // category 2
      logPxy[3] = log_p_x[3] + bernoulli_logit_lpmf(y[i] | beta3); // category 3
      target += log_sum_exp(logPxy); // sum probabilities across the scenarios (i.e., marginalize over missingness)
    }
    else { // x is observed
      // likelihood statement for x categorical variable; categorical_logit
      // applies the softmax to score internally
      x_obs[i] ~ categorical_logit(score);
      // likelihood statement for outcome; exactly one dummy is 1, so this
      // selects the coefficient of the observed category
      y[i] ~ bernoulli_logit(beta1*x_cat_1[i] + beta2*x_cat_2[i] + beta3*x_cat_3[i]);
    }
  } // close loop
} // close model block
///////////////////////////////////////////////////////////////////////////////
generated quantities { // generate estimates of the imputed category for all observations
  matrix[N, K] x_imp; // for all N observations, the probability of belonging to category 1:K

  for (i in 1:N) {
    if (x_miss[i] == 1) { // x is unobserved
      // want probability of the unobserved x value belonging to each category,
      // given the observed y value
      // so with K = 3 we want: Pr(1|y), Pr(2|y), and Pr(3|y)
      // which is equivalent to: Pr(1,y)/Pr(y), Pr(2,y)/Pr(y), and Pr(3,y)/Pr(y)
      vector[K] score; // vector containing log-odds of being in category 1, 2, or 3
      vector[K] log_p_x; // log probability of each x category
      vector[K] logPxy; // vector for log Pr(1,y), log Pr(2,y), and log Pr(3,y) values
      real logPy; // log Pr(y) value

      score[1] = 0; // score[1] is fixed at zero
      score[2] = a_cat2 + b_cat2*cov_for_x[i]; // modeling the log-odds of category 2
      score[3] = a_cat3 + b_cat3*cov_for_x[i]; // modeling the log-odds of category 3
      log_p_x = log_softmax(score);

      // calculate Pr(x,y) values for 1:K (on the log scale)
      logPxy[1] = log_p_x[1] + bernoulli_logit_lpmf(y[i] | beta1); // Pr(1,y) = Pr(1)Pr(y|1)
      logPxy[2] = log_p_x[2] + bernoulli_logit_lpmf(y[i] | beta2); // Pr(2,y) = Pr(2)Pr(y|2)
      logPxy[3] = log_p_x[3] + bernoulli_logit_lpmf(y[i] | beta3); // Pr(3,y) = Pr(3)Pr(y|3)

      // calculate Pr(y) (as in the model likelihood statement)
      logPy = log_sum_exp(logPxy); // marginalize over missingness

      // populate the x_imp matrix row for the ith observation
      x_imp[i, 1] = exp(logPxy[1] - logPy); // Pr(1|y) = Pr(1,y)/Pr(y)
      x_imp[i, 2] = exp(logPxy[2] - logPy); // Pr(2|y) = Pr(2,y)/Pr(y)
      x_imp[i, 3] = exp(logPxy[3] - logPy); // Pr(3|y) = Pr(3,y)/Pr(y)
    }
    else { // x is observed
      x_imp[i] = rep_row_vector(0, K);
      x_imp[i, x_obs[i]] = 1; // when the category has been observed, we know the category
    }
  } // close loop
} // close generated quantities block
"
#==============================================================================
# Fit the data
# Every variable declared in the Stan data block is passed explicitly in the
# data list, including the three category dummies, so the fit does not depend
# on objects being looked up outside of the list
fit_model <-
  stan(model_code = model,
       data = list(N = N, K = K, y = y, x_obs = x_obs,
                   x_miss = x_miss, cov_for_x = cov_for_x,
                   x_cat_1 = x_cat_1, x_cat_2 = x_cat_2, x_cat_3 = x_cat_3),
       iter = 1000, chains = 4, cores = 4,
       # high adapt_delta and max_treedepth values passed to the sampler
       control = list(adapt_delta = 0.995, max_treedepth = 15))
# Inspect the fitted model: print a parameter summary, then pull out the
# posterior draws as a named list
precis(fit_model)
out <- extract(fit_model)

# Relationship between category and outcome: the inverse-logit of each
# posterior mean coefficient should be close to the simulated Pr(y = 1)
# for that category
logistic(mean(out[["beta1"]])) # expect roughly 0.8
logistic(mean(out[["beta2"]])) # expect roughly 0.3
logistic(mean(out[["beta3"]])) # expect roughly 0.1

# Relative probability of belonging to each category, computed from the
# posterior mean intercepts (category 1 is the reference, with score 0)
softmax(0, mean(out[["a_cat2"]]), mean(out[["a_cat3"]]))
# expect roughly 0.5, 0.3, 0.2
# Caveat: if the missingness highly skews the observed x distribution
# relative to the true x distribution, I do not think this will hold

# The model also estimates, for each missing x value, the probability of
# belonging to each category
# Average x_imp over the posterior draws (first array dimension), which
# leaves an N x K matrix of mean probabilities
post.prob.means <- apply(out[["x_imp"]], MARGIN = c(2, 3), FUN = mean)

# Rows of that matrix for observations whose x was missing
post.prob.means[as.logical(x_miss), ]

# The same rows next to the corresponding observed y values; the category
# probabilities shift depending upon the observed outcome
cbind(post.prob.means[as.logical(x_miss), ], y[as.logical(x_miss)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment