Last active
December 7, 2020 22:59
Implementation of Monte Carlo EM algorithm for reconstructing a standard distribution from censored observations
### This is a demonstration of a Monte Carlo Expectation Maximization
### algorithm that can recover the mean and standard deviation of
### truncated normally distributed data. We get 10,000 samples from
### a unit normal distribution, but every sample below 0.5 is truncated
### to that value. Every sample above 2.5 is truncated to that value.
### These choices were made to get quick and visually appealing convergence
### but the algorithm still converges for any choice. The convergence
### could be very, very slow if there is little information in the samples
### and the final answer could have substantial uncertainty. For instance,
### if we truncated at 4 and 6, almost all samples would be piled up at
### the lower limit and all we really can know is that the mean is less
### than the lower limit and the standard deviation is a small fraction of
### the distance from the mean to the lower limit.
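### A rough way to see how much information survives a choice of bounds
### is to compute the fraction of samples expected to pile up at each
### limit. This is only a sketch; censored.fractions is a hypothetical
### helper, not part of the original code, and it assumes the true
### distribution is normal with the given mean and sd.
censored.fractions = function(a, b, m=0, s=1) {
  low = pnorm(a, m, s)                       # mass at or below a
  high = pnorm(b, m, s, lower.tail=FALSE)    # mass at or above b
  c(low=low, uncensored=1 - low - high, high=high)
}
### For example, censored.fractions(0.5, 2.5) leaves roughly 30% of a
### unit normal uncensored, while censored.fractions(4, 6) leaves
### almost nothing.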
### The way that this works is that we keep a current estimate of mean
### and standard deviation. At each step, we replace samples that were
### truncated with samples taken from our currently estimated
### distribution. Then we combine the uncensored samples with these
### synthetic samples and compute the new mean and standard deviation.
### This algorithm works very well even when our initial estimates are
### absurdly wrong.
### For a pretty show, try running demo()
### This code has a few nice tricks. One is the use of the CDF of the
### normal to get a quantile range. We can then take uniform samples
### from that range and use the inverse CDF to get suitably truncated
### samples from the normal. Without this, we might use an MCMC sampler
### based on Metropolis-Hastings. The CDF trick is just much faster and
### only takes a few lines of code.
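### As a standalone sketch of the CDF trick described above, the
### hypothetical helper below (not part of the original code) draws n
### samples from a normal with mean m and sd s restricted to [lo, hi].
### It assumes lo < hi and that the interval is not so far out in the
### tails that pnorm rounds both ends to the same value.
sample.truncated = function(n, lo, hi, m=0, s=1) {
  p.lo = pnorm(lo, m, s)            # quantile of the lower bound
  p.hi = pnorm(hi, m, s)            # quantile of the upper bound
  u = runif(n, p.lo, p.hi)          # uniform over that quantile range
  x = qnorm(u, m, s)                # map back through the inverse CDF
  pmin(hi, pmax(lo, x))             # guard against rounding at the edges
}
### For example, sample.truncated(5, -Inf, 0.5) gives five draws from
### the lower tail of a unit normal, none larger than 0.5.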
### truncation bounds
a = 0.5
b = 2.5
### Raw data and truncated data
z.0 = rnorm(10000)
z = pmin(b, pmax(a, z.0))
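### A quick check, sketched here with a hypothetical helper that is not
### part of the original code, counts how many samples ended up at each
### bound. With a = 0.5 and b = 2.5, roughly 69% of unit normal samples
### should sit at the lower bound and under 1% at the upper bound.
censor.counts = function(data, a, b) {
  c(low=sum(data <= a), uncensored=sum(data > a & data < b), high=sum(data >= b))
}
### For example: censor.counts(z, a, b) / length(z)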
### Demo example
demo = function() {
  ## start with crazy estimate
  ms = c(5, 1)
  for (i in 1:100) {
    ms = step(ms, z, a, b, TRUE)
    Sys.sleep(0.1)
  }
}
### Record frames to show the evolution and then stitch into a video
video = function() {
  ## start with crazy estimate
  ms = c(-5, 0.1)
  system("rm -rf frames")
  system("mkdir frames")
  for (i in 1:200) {
    png(sprintf("frames/f-%04d.png", i), 1920, 1080, pointsize=24)
    ms = step(ms, z, a, b, TRUE, lwd=15)
    dev.off()
  }
  ## -f so that a missing mcem.mp4 on the first run does not complain
  system("rm -f mcem.mp4")
  system("ffmpeg -r 10 -f image2 -s 1920x1080 -i frames/f-%04d.png -vcodec libx264 -crf 25 -pix_fmt yuv420p mcem.mp4")
}
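### video() shells out to ffmpeg, so it can help to check that the tool
### is on the PATH first. A minimal sketch; have.tool is a hypothetical
### helper that is not part of the original code.
have.tool = function(name) {
  ## Sys.which returns "" when the executable cannot be found
  nzchar(Sys.which(name)[[1]])
}
### For example: if (have.tool("ffmpeg")) video()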
### This is where the Monte Carlo E step and the M step are done
step = function(ms, data, a, b, plot=TRUE, lwd=5) {
  ## Unpack current mean and sd
  m = ms[1]
  s = ms[2]
  ## Find the censored samples
  low = data <= a
  high = data >= b
  uncensored = data[(!low) & (!high)]
  ## Find bounds for resampling the censored data
  ## we have to avoid returning exactly 0 or 1 here
  p.a = max(1e-8, min(1-1e-8, pnorm(a, m, s)))
  p.b = max(1e-8, min(1-1e-8, pnorm(b, m, s)))
  ## transform uniform samples into normally distributed
  ## these should be in the censored regions, but we
  ## touch them up a bit if they encroached. That happens
  ## if our current distribution is crazy
  re.a = pmin(a, qnorm(runif(sum(low), 0, p.a), m, s))
  re.b = pmax(b, qnorm(runif(sum(high), p.b, 1), m, s))
  re.data = c(re.a, uncensored, re.b)
  if (plot) {
    ## plots a histogram on a constant range of the x axis. Bins between
    ## a and b (observed data) are drawn in red, resampled bins in black
    brk = seq(-5.1, 5.1, by=0.1)
    h = hist(re.data[(re.data>-5) & (re.data<5)], breaks=brk, plot=FALSE)
    plot(c(), c(), xlim=c(-5,5), ylim=c(0,500), xlab="x", ylab="count")
    text(2.8, 300, sprintf("mean = %.2f", ms[1]), adj=0)
    text(2.8, 280, sprintf("sd = %.3f", ms[2]), adj=0)
    ## note: z.0 is the global raw data defined above
    text(2.8, 260, sprintf("(original m=%.2f, sd=%.3f)", mean(z.0), sd(z.0)), adj=0)
    legend(2.7, 370, legend=c("Observed data", "Monte Carlo data"),
           fill=c('red','black'))
    col = rgb((h$mids > a) & (h$mids < b), 0, 0, alpha=0.8)
    ## ylim is already fixed by plot() above, so it is not repeated here
    lines(h$mids, h$counts, type='h', lwd=lwd, col=col)
  }
  ## compute new estimates
  c(mean(re.data), sd(re.data))
}
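### To watch the estimates without any plotting, the loop in demo() can
### be wrapped so that it records each iteration. This is a sketch;
### run.mcem is a hypothetical helper that is not part of the original
### code, and it assumes step.fn has the same arguments and return
### value as step() above.
run.mcem = function(ms, data, a, b, n=100, step.fn=step) {
  path = matrix(NA_real_, nrow=n, ncol=2, dimnames=list(NULL, c("mean", "sd")))
  for (i in seq_len(n)) {
    ms = step.fn(ms, data, a, b, plot=FALSE)
    path[i, ] = ms    # row i holds the estimate after i steps
  }
  path
}
### For example, tail(run.mcem(c(5, 1), z, a, b), 1) should end up near
### the mean and sd of z.0 once the iteration has settled, up to Monte
### Carlo noise.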