Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Created October 24, 2022 20:20
Show Gist options
  • Save vankesteren/8ba9aab6f4e2dbe940de459c6d8ffd0d to your computer and use it in GitHub Desktop.
Save vankesteren/8ba9aab6f4e2dbe940de459c6d8ffd0d to your computer and use it in GitHub Desktop.
Univariate t-mixture modelling with R. Also works for Gaussian components (albeit a bit slowly)!
# t-mixture modeling with EM

# True parameters of the simulated two-class mixture
priorp <- 0.6 # probability of belonging to class 1
m1 <- 0       # location, class 1
m2 <- 3       # location, class 2
s1 <- 1       # scale, class 1
s2 <- 0.707   # scale, class 2
df1 <- 2      # degrees of freedom, class 1 (heavy tails)
df2 <- Inf    # degrees of freedom, class 2 (Inf = Gaussian)

# generate some data with 2 classes
N <- 1000
cl <- rbinom(N, 1, priorp) # 1 = class 1, 0 = class 2
# Location-scale t draws are m + s * t. The location has to be added AFTER
# scaling: (t + m) * s would centre the class at m * s instead of m, which
# disagrees with the (x - m) / s standardisation used by the densities below.
x <- cl * (m1 + s1 * rt(N, df = df1)) +
  (1 - cl) * (m2 + s2 * rt(N, df = df2))
# Maximum likelihood estimates of location-scale t-distribution
#
# Maximises the weighted log-likelihood sum(w * log f(x)) of a location-scale
# t-distribution over location, scale and degrees of freedom using BFGS.
#
# x: numeric vector of observations.
# w: numeric vector of observation weights, same length as x.
#
# Returns a list with elements mu (location), sigma (scale) and df (degrees
# of freedom, parameterised as exp(p[3]) + 1). Emits a warning if optim()
# reports non-convergence.
t_mle <- function(x, w) {
  if (length(x) != length(w)) {
    stop("`x` and `w` must have the same length", call. = FALSE)
  }

  # Unconstrained parameterisation, p[1]: mean, p[2]: log(sd),
  # p[3]: log(df - 1)
  weighted_loglik <- function(p) {
    # log density of a location-scale t: log dt(z) - log(sd)
    ll <- dt((x - p[1]) / exp(p[2]), df = exp(p[3]) + 1, log = TRUE) - p[2]
    sum(ll * w)
  }

  res <- optim(
    par = c(0, 0, 1),
    fn = weighted_loglik,
    method = "BFGS",
    control = list(fnscale = -1) # max, not min
  )

  if (res$convergence != 0) {
    # res$message can be NULL (optim only fills it for some methods), so the
    # convergence code is always reported to keep the warning informative.
    warning(
      "t_mle: optim() did not converge (code ", res$convergence, ")",
      if (!is.null(res$message)) paste0(": ", res$message),
      call. = FALSE
    )
  }

  list(mu = res$par[1], sigma = exp(res$par[2]), df = exp(res$par[3]) + 1)
}
# the e-step, compute posterior probabilities from density
#
# x:     numeric vector of observations.
# theta: list with per-class vectors m (location), s (scale) and df (degrees
#        of freedom). May also contain p (mixing proportions); when p is
#        absent all classes get equal weight.
# K:     number of classes.
#
# Returns a length(x) by K matrix of posterior class probabilities; each row
# sums to 1.
estep <- function(x, theta, K = 2) {
  # [["p"]] instead of $p so that no other element is partially matched
  mix_p <- theta[["p"]]
  log_prior <- if (is.null(mix_p)) rep(0, K) else log(mix_p)

  logd <- matrix(0.0, length(x), K)
  for (k in seq_len(K)) {
    # log density of a location-scale t: log dt(z) - log(s). The -log(s) term
    # is required, otherwise classes with different scales are not compared
    # on the same footing.
    logd[, k] <- log_prior[k] +
      dt((x - theta$m[k]) / theta$s[k], df = theta$df[k], log = TRUE) -
      log(theta$s[k])
  }

  # Normalise in log space: subtracting the row maximum before exponentiating
  # avoids 0 / 0 when every class density underflows for an extreme x.
  d <- exp(logd - apply(logd, 1, max))
  # d / rowSums(d) keeps the N x K shape, also for K = 1
  d / rowSums(d)
}
# the m-step, compute parameters using posterior probability
#
# x:     numeric vector of observations.
# postp: length(x) by K matrix of posterior class probabilities.
# df:    unused; kept so existing calls keep working. The degrees of freedom
#        are estimated per class by t_mle().
# K:     number of classes.
#
# Returns a list with per-class vectors m (location), s (scale), df (degrees
# of freedom) and p (mixing proportions, the column means of postp).
mstep <- function(x, postp, df = 2, K = 2) {
  scale_k <- numeric(K)
  loc_k <- numeric(K)
  # separate local name so the `df` argument is not silently overwritten
  df_k <- numeric(K)
  for (k in seq_len(K)) {
    # weighted ML fit of class k, weights = posterior membership probability
    res_k <- t_mle(x, w = postp[, k])
    scale_k[k] <- res_k$sigma
    loc_k[k] <- res_k$mu
    df_k[k] <- res_k$df
  }
  mix_p <- unname(colMeans(postp[, seq_len(K), drop = FALSE]))
  list(m = loc_k, s = scale_k, df = df_k, p = mix_p)
}
# plotting function
#
# Density of the fitted mixture evaluated at x.
#
# x:     numeric vector of points at which to evaluate the density.
# postp: matrix of posterior class probabilities (one column per class); its
#        column means are used as mixing proportions.
# theta: list with per-class vectors m (location), s (scale) and df (degrees
#        of freedom).
#
# Returns a numeric vector of mixture density values, same length as x.
dmix <- function(x, postp, theta) {
  priorp <- colMeans(postp)
  d <- 0
  for (k in seq_along(priorp)) {
    # density of a location-scale t is dt(z) / s; without the 1 / s factor
    # the curve does not integrate to 1 whenever s != 1
    z <- (x - theta$m[k]) / theta$s[k]
    d <- d + priorp[k] * dt(z, df = theta$df[k]) / theta$s[k]
  }
  d
}
# starting values for the two classes
theta <- list(m = c(-1, 1), s = c(1, 1), df = c(4, 4))

# EM: alternate E-step and M-step for a fixed 100 iterations
for (i in seq_len(100)) {
  postp <- estep(x = x, theta = theta, K = 2)
  theta <- mstep(x = x, postp = postp, K = 2)
}

# histogram of the data with the fitted mixture density drawn on top
hist(
  x,
  breaks = "FD",
  freq = FALSE,
  xlim = c(-5, 7),
  main = paste("Iteration", i)
)
curve(dmix(x, postp, theta), from = -6, to = 8, n = 1000, add = TRUE)

# fitted class parameters and estimated class proportions
theta
colMeans(postp)
@vankesteren
Copy link
Author

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment