@rmcelreath
Created March 8, 2017 10:39
Discrete missing values in Stan
# "impute" missing binary predictor
# really just marginalizes over missingness
# imputed values produced in generated quantities
N <- 1000 # number of cases
N_miss <- 100 # number missing values
x_baserate <- 0.25 # prob x==1 in total sample
a <- 0 # intercept in y ~ N( a+b*x , 1 )
b <- 1 # slope in y ~ N( a+b*x , 1 )
# simulate data
x <- sample( 0:1 , size=N , replace=TRUE , prob=c(1-x_baserate,x_baserate) )
i_miss <- sample( 1:N , size=N_miss )
x_obs <- x
x_obs[i_miss] <- (-1)
x_NA <- x_obs
x_NA[i_miss] <- NA
x_miss <- ifelse( 1:N %in% i_miss , 1 , 0 )
y <- rnorm( N , a + b*x , 1 )
m_code <- "
data{
    int N;
    int x[N];
    int x_miss[N];
    real y[N];
}
parameters{
    real a;
    real b;
    real<lower=0,upper=1> x_mu;
}
model{
    a ~ normal(0,10);
    b ~ normal(0,1);
    x_mu ~ beta(1,1);
    for ( i in 1:N ) {
        if ( x_miss[i]==1 ) {
            // x missing: marginalize over x==1 and x==0
            target += log_mix( x_mu ,
                normal_lpdf( y[i] | a + b , 1 ),
                normal_lpdf( y[i] | a , 1 )
            );
        } else {
            // x not missing
            x[i] ~ bernoulli(x_mu);
            y[i] ~ normal( a + b*x[i] , 1 );
        }
    }//i
}//model
generated quantities{
    vector[N] x_impute;
    for ( i in 1:N ) {
        real logPxy;
        real logPy;
        if ( x_miss[i]==1 ) {
            // need P(x|y)
            // P(x|y) = P(x,y)/P(y)
            // P(x,y) = P(x)P(y|x)
            // P(y) = P(x==1)P(y|x==1) + P(x==0)P(y|x==0)
            logPxy = log(x_mu) + normal_lpdf(y[i]|a+b,1);
            logPy = log_mix( x_mu ,
                normal_lpdf( y[i] | a + b , 1 ),
                normal_lpdf( y[i] | a , 1 ) );
            x_impute[i] = exp( logPxy - logPy );
        } else {
            x_impute[i] = x[i];
        }
    }//i
}//gq
"
library(rethinking)
m <- stan(
model_code=m_code ,
data=list(N=N,y=y,x=x_obs,x_miss=x_miss),
chains=1 )
precis(m)
# show imputed medians
post <- extract.samples(m)
Px <- apply(post$x_impute,2,median)
pt1 <- mean( Px[x_miss==1 & x==1] )
pt0 <- mean( Px[x_miss==1 & x==0] )
plot( x[x_miss==1] , Px[x_miss==1] , ylim=c(0,1) , xlab="true" , ylab="imputed probability == 1" )
points( c(0,1) , c(pt0,pt1) , pch=16 , col="red" )
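# --- sanity check: a sketch added for illustration, not part of the original model ---
# The generated quantities block computes P(x==1|y) by Bayes' rule.
# The same quantity can be computed in plain R for fixed parameter values.
# p_x1_given_y is a hypothetical helper name; sigma defaults to 1 as in the model.
p_x1_given_y <- function( y , a , b , x_mu , sigma=1 ) {
    # log P(x==1) + log P(y|x==1) and log P(x==0) + log P(y|x==0)
    lp1 <- log(x_mu) + dnorm( y , a + b , sigma , log=TRUE )
    lp0 <- log1p(-x_mu) + dnorm( y , a , sigma , log=TRUE )
    # normalize on the log scale (log-sum-exp), then exponentiate
    m <- pmax( lp1 , lp0 )
    exp( lp1 - ( m + log( exp(lp1-m) + exp(lp0-m) ) ) )
}
# y halfway between a and a+b is equally likely under x==0 and x==1,
# so the posterior probability falls back to the base rate
p_x1_given_y( y=0.5 , a=0 , b=1 , x_mu=0.25 )
# if the model above has been fit, posterior means can be plugged in for a rough
# comparison with Px (plug-in values will not match the per-sample medians exactly):
# Px_check <- p_x1_given_y( y[x_miss==1] , mean(post$a) , mean(post$b) , mean(post$x_mu) )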
@MadathilSA

Hi Dr. McElreath,

Thank you for the code.
Which missing-value mechanism is assumed in this case, that is, when we marginalize over discrete missing values?
Is it Missing Completely at Random?

Thanks
Sreenath

@erik-ringen

This is awesome. Would it be straightforward to extend this to non-binary discrete values?

@dmontecino

With the help of @erik-ringen, here is a version for those interested in more than 2 categories:
https://gist.github.com/dmontecino/b804853e4b36a57990a7108a35201cf5

@eveskew

eveskew commented Dec 4, 2018

And, following discussion with @dmontecino, one more attempt:
https://gist.github.com/eveskew/cbce607e252638f5ebf082c634fa814d
In the generated quantities block, I compute a matrix representing the probability that a missing x value belongs to each category, given the observed y outcome.
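
The idea of a per-category probability matrix can be sketched in plain R for fixed parameter values. This is only an illustration under stated assumptions, not the code from the linked gist: x takes K categories with prior probabilities `theta`, y is Normal with mean `mu[k]` when x is in category k, and `category_probs`, `theta`, and `mu` are hypothetical names.

```r
# Posterior probability of each category for a missing x, given y.
# Assumptions (illustrative only): K categories with prior probabilities theta,
# and y ~ Normal( mu[k] , sigma ) when x == k.
category_probs <- function( y , theta , mu , sigma=1 ) {
    # N x K matrix of log P(x==k) + log P(y|x==k)
    lp <- sapply( seq_along(theta) , function(k)
        log(theta[k]) + dnorm( y , mu[k] , sigma , log=TRUE ) )
    lp <- matrix( lp , nrow=length(y) )
    # subtract the row-wise log-sum-exp so each row is a proper distribution
    m <- apply( lp , 1 , max )
    lse <- m + log( rowSums( exp( lp - m ) ) )
    exp( lp - lse )
}

# three cases with missing x, three categories
P <- category_probs( y=c(-1,0.5,2) , theta=c(0.5,0.3,0.2) , mu=c(0,1,2) )
rowSums(P) # each row sums to 1 (up to floating point)
```

With K = 2, `theta = c(1-x_mu, x_mu)`, and `mu = c(a, a+b)`, the second column reduces to the binary P(x==1|y) computed in the original gist.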
