Created
September 4, 2023 13:12
-
-
Save MaverickMeerkat/838af93586a130347acc59e22a568e5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Expectation Maximization - Gaussian Mixture Model | |
library(ks) # for kde | |
library(mvtnorm) # for MVNormal | |
library(extraDistr) # for Categorical & Dirichlet distribution | |
library(ggplot2) # for plotting | |
library(pracma) # for sqrtm | |
######################
#  1D, 2-components  #
######################

## Both sigmas and p are treated as known; only the two means are estimated
set.seed(247)
p = 0.3   # probability that an observation belongs to component 1
n = 2000  # sample size

# Latent component indicator (1 = component 1, 0 = component 0)
z = rbinom(n, size=1, prob=p)

# Observed (generated) data
mu1 = 2
sig1 = sqrt(2)
mu0 = -2
sig0 = sqrt(1)
# Component-1 draws are generated first, then component-0 draws;
# z picks one of the two for every observation.
x = z*rnorm(n, mean=mu1, sd=sig1) + (1-z)*rnorm(n, mean=mu0, sd=sig0)

# Histogram of the simulated data
hist(x, breaks=50)

# Kernel density estimate of the simulated data
plot(kde(x))
# E step
# Posterior probability that each observation belongs to component 1, given
# the current means. x, sig0, sig1 and p are read from the enclosing (global)
# environment and are treated as known here.
#   theta:   c(mu0, mu1)
#   returns: numeric vector with one responsibility per element of x
doE = function(theta) {
  mu0 = theta[1]
  mu1 = theta[2]
  # dnorm() is vectorised, so the whole sample is handled in one call and the
  # result length always follows x (no dependence on a separate global n)
  a = dnorm(x, mean=mu1, sd=sig1) * p
  b = dnorm(x, mean=mu0, sd=sig0) * (1 - p)
  a / (a + b)
}
# M step
# Re-estimate both component means as responsibility-weighted averages of the
# global sample x.
#   E:       responsibilities for component 1 (one per element of x)
#   returns: c(mu0, mu1)
doM = function(E) {
  w1 = E      # weight of component 1 for every observation
  w0 = 1 - E  # weight of component 0 for every observation
  c(sum(w0*x)/sum(w0), sum(w1*x)/sum(w1))
}
# Alternate E and M steps until the parameter vector stops moving.
#   theta:   starting parameter vector
#   maxIter: maximum number of E/M rounds
#   tol:     stop once the Euclidean distance between two successive
#            estimates drops below this value
#   returns: the most recent parameter estimate
EM = function(theta, maxIter=50, tol=1e-5) {
  theta.t = theta
  # seq_len() gives an empty loop for maxIter = 0; 1:0 would run two rounds
  for (i in seq_len(maxIter)) {
    E = doE(theta.t)
    theta.t = doM(E)
    cat("Iteration: ", i, "pt: ", theta.t, "\n")
    # explicit Euclidean distance; works for plain vectors and 1-row matrices
    step = sqrt(sum((theta.t - theta)^2))
    if (step < tol) break
    theta = theta.t
  }
  return(theta.t)
}
# Initial values: one random starting point (mu0, mu1) drawn around the origin
theta0 = rmvnorm(n=1, mean=rep(0, 2))
# Run EM from theta0; the outer parentheses print the estimate
(theta.final = EM(theta=theta0))
## estimate p, mu's & sigma's
# E step
# Posterior probability that each observation of the global sample x belongs
# to component 1, given the current parameter vector.
#   theta:   c(mu0, sig0^2, mu1, sig1^2, p); variances are converted to
#            standard deviations before being passed to dnorm()
#   returns: numeric vector with one responsibility per element of x
doE = function(theta) {
  mu0 = theta[1]
  sig0 = sqrt(theta[2])
  mu1 = theta[3]
  sig1 = sqrt(theta[4])
  p = theta[5]
  # dnorm() is vectorised, so the whole sample is handled in one call and the
  # result length always follows x (no dependence on a separate global n)
  a = dnorm(x, mean=mu1, sd=sig1) * p
  b = dnorm(x, mean=mu0, sd=sig0) * (1 - p)
  a / (a + b)
}
# M step
# Re-estimate means, variances and the mixing weight from the
# responsibilities, using the global sample x.
#   E:       responsibilities for component 1 (one per element of x)
#   returns: c(mu0, sig0^2, mu1, sig1^2, p)
doM = function(E) {
  w1 = E      # weight of component 1 for every observation
  w0 = 1 - E  # weight of component 0 for every observation

  m0 = sum(w0*x)/sum(w0)
  v0 = sum(w0*(x-m0)^2)/sum(w0)

  m1 = sum(w1*x)/sum(w1)
  v1 = sum(w1*(x-m1)^2)/sum(w1)

  c(m0, v0, m1, v1, mean(E))
}
# Alternate E and M steps until the parameter vector stops moving.
#   theta:   starting parameter vector
#   maxIter: maximum number of E/M rounds
#   tol:     stop once the Euclidean distance between two successive
#            estimates drops below this value
#   returns: the most recent parameter estimate
EM = function(theta, maxIter=200, tol=1e-5) {
  theta.t = theta
  # seq_len() gives an empty loop for maxIter = 0; 1:0 would run two rounds
  for (i in seq_len(maxIter)) {
    E = doE(theta.t)
    theta.t = doM(E)
    cat("Iteration: ", i, "pt: ", theta.t, "\n")
    # explicit Euclidean distance; works for plain vectors and 1-row matrices
    step = sqrt(sum((theta.t - theta)^2))
    if (step < tol) break
    theta = theta.t
  }
  return(theta.t)
}
# Initial values
theta0 = c(
  0,    # mu0
  0.5,  # sig0^2
  0,    # mu1
  3,    # sig1^2
  0.7   # p
)
# Run EM from theta0; the outer parentheses print the estimate
(theta.final = EM(theta=theta0))
##################
#  General Case  #
##################
# K x 2D Gaussians
set.seed(247)

# Params
K = 3
phis = rdirichlet(1, rep(1, K))  # mixing weights (a single row of length K)
j = 2 # how far apart the centers are
mus = matrix(c(j,-j,0,j,j,-j), ncol=2) # rmvnorm(K, mean=c(0,0))
# One covariance per component: sample covariance of 50 standard-normal pairs
Sigmas = vector("list", K)
for (k in seq_len(K)) {
  mat = matrix(rnorm(100), ncol=2)
  Sigmas[[k]] = cov(mat)
}

# Latent component label of every observation (n is reused from the 1D section)
z = rcat(n, phis)

# Observed (generated): each row is drawn from the Gaussian selected by z
x = matrix(nrow=n, ncol=2)
for (i in seq_len(n)) {
  k = z[i]
  mu = mus[k,]
  sigma = Sigmas[[k]]
  x[i,] = rmvnorm(1, mean=mu, sigma=sigma)
}

# plot the data
# data.frame(x=x) expands the matrix into columns x.1 and x.2; mapping those
# columns keeps the plot tied to df instead of to the global matrix x
df = data.frame(x=x, mu=as.factor(z))
(plt = ggplot(df, aes(x=x.1, y=x.2, color=mu, fill=mu)) + geom_point() +
  xlab("x") + ylab("y"))
# E step
# Responsibility of every component for every row of the global matrix x.
#   theta:   list with elements mus, Sigmas and phis (one entry per component)
#   returns: matrix with one row per observation and one column per
#            component; each row is divided by its own sum
doE = function(theta) {
  # weight of component k times its density at every row of x
  weighted_density = function(k) {
    theta$phis[[k]] * dmvnorm(x, theta$mus[[k]], theta$Sigmas[[k]])
  }
  joint = do.call(cbind, lapply(1:K, weighted_density))
  joint / rowSums(joint)
}
# M step
# Weighted re-estimation of every component from the responsibilities, using
# the global matrix x.
#   E:       responsibility matrix (rows = observations, columns = components)
#   returns: list(mus, Sigmas, phis) with one entry per component
doM = function(E) {
  # cov.wt() with method="ML" returns the weighted mean ("center") and the
  # weighted maximum-likelihood covariance ("cov") for one component
  fits = lapply(1:K, function(k) cov.wt(x, E[,k], method="ML"))
  list(
    mus = lapply(fits, function(fit) fit[["center"]]),
    Sigmas = lapply(fits, function(fit) fit[["cov"]]),
    phis = colMeans(E)
  )
}
# Log-likelihood of the global matrix x under the mixture described by theta.
#   theta:   list with elements mus, Sigmas and phis (one entry per component)
#   returns: a single number, the sum over rows of log(mixture density)
logLikelihood = function(theta) {
  # weight of component i times its density at every row of x
  component_density = function(i) {
    theta$phis[i] * dmvnorm(x, theta$mus[[i]], theta$Sigmas[[i]])
  }
  mixture = rowSums(do.call(cbind, lapply(1:K, component_density)))
  sum(log(mixture))
}
# Alternate E and M steps until the log-likelihood stops changing.
#   theta:   starting parameters, list(mus, Sigmas, phis)
#   maxIter: maximum number of E/M rounds
#   tol:     stop once the absolute change in log-likelihood between two
#            successive estimates drops below this value
#   returns: the most recent parameter estimate
EM = function(theta, maxIter=30, tol=1e-1) {
  theta.t = theta
  # log-likelihood of the previous estimate, carried between rounds so that
  # it is evaluated only once per parameter set
  ll.prev = logLikelihood(theta)
  # seq_len() gives an empty loop for maxIter = 0; 1:0 would run two rounds
  for (i in seq_len(maxIter)) {
    E = doE(theta.t)
    theta.t = doM(E)
    ll.t = logLikelihood(theta.t)
    ll.diff = ll.t - ll.prev
    cat("Iteration: ", i, " ll difference: ", ll.diff, "\n")
    if (abs(ll.diff) < tol) break
    theta = theta.t
    ll.prev = ll.t
  }
  return(theta.t)
}
# Initial values
set.seed(1)
phis0 = rdirichlet(1, rep(1, K))  # random starting mixing weights
mus0 = vector("list", K)
Sigmas0 = vector("list", K)
for (k in seq_len(K)) {
  # The noise matrix is drawn before the starting mean so that the random
  # number stream is consumed in the same order on every run
  mat = matrix(rnorm(100), ncol=2)
  mus0[[k]] = rmvnorm(1, mean=c(3,-3))
  Sigmas0[[k]] = cov(mat)
}
theta0 = list(mus=mus0, Sigmas=Sigmas0, phis=phis0)

# Run EM from theta0; the outer parentheses print the estimate
(theta.final = EM(theta=theta0))
# Plot the fitted components
# Points on a circle, used as the template for the component outlines.
#   center:   c(x, y) of the circle centre
#   diameter: circle diameter
#   npoints:  number of points; first and last coincide (angles 0 and 2*pi)
#   returns:  data.frame with columns x and y
circleFun = function(center=c(0,0), diameter=1, npoints=100){
  angle = seq(0, 2*pi, length.out=npoints)
  radius = diameter / 2
  data.frame(
    x = center[1] + radius * cos(angle),
    y = center[2] + radius * sin(angle)
  )
}
# Overlay one fitted Gaussian component on an existing scatter plot.
#   plt:     ggplot object to add layers to
#   center:  component mean, c(x, y)
#   Sigma:   2x2 covariance matrix of the component
#   col:     outline and fill colour of the outline polygon
#   returns: plt with theme, axis limits/labels, a centre marker and the
#            component outline added
plotCircle = function(plt, center, Sigma, col="#000000") {
  # Circle of diameter 4 around the origin, multiplied by the matrix square
  # root of Sigma (element B of pracma::sqrtm) and shifted to `center`
  dat = circleFun(c(0,0), 4, npoints=100)
  dat1 = sweep(as.matrix(dat) %*% sqrtm(Sigma)$B, MARGIN=2, center, "+")
  plt + theme_light() + theme(legend.position="none") +
    ylim(-6,6) + xlim(-6,6) + xlab("x") + ylab("y") +
    # annotate() draws a single marker; a geom_point() layer here would
    # inherit the plot data and redraw the centre once per data row
    annotate("point", x=center[1], y=center[2], colour="#000000") +
    # the outline has its own data, so it must not inherit the plot-level
    # colour/fill mapping
    geom_polygon(data=as.data.frame(dat1), mapping=aes(x=V1, y=V2),
                 color=col, fill=col, alpha=0.3, inherit.aes=FALSE)
}
# Scatter plot of the data coloured by true component, with every fitted
# component overlaid. The mapping uses the x.1 / x.2 columns that
# data.frame(x=x) created in df rather than indexing the global matrix x.
plt = ggplot(df, aes(x=x.1, y=x.2, color=mu, fill=mu)) + geom_point()
# one overlay per fitted component, however many were estimated
for (k in seq_along(theta.final$mus)) {
  plt = plotCircle(plt, theta.final$mus[[k]], theta.final$Sigmas[[k]])
}
(plt)
# D. Refaeli | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment