@chris-taylor
Created December 5, 2012 10:07
Expectation Maximization
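The listing fits the following model. This is read off the code rather than stated by the author. Each of the T trials (rows of X) is assumed to be N flips of a single coin, chosen from the P coins with equal probability 1/P. Trial t yields K_t heads. The function searches for the heads probabilities theta that maximize

```latex
p(X \mid \theta) = \prod_{t=1}^{T} \frac{1}{P} \sum_{j=1}^{P} \binom{N}{K_t}\, \theta_j^{K_t} (1-\theta_j)^{N-K_t},
\qquad K_t = \sum_{n=1}^{N} X_{tn}
```

Because the coin behind each trial is unobserved, this likelihood has no closed-form maximizer. EM alternates between softly assigning trials to coins and re-estimating theta, and each iteration cannot decrease the likelihood.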
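function theta = em(X,theta)
% EM  Expectation maximization for a mixture of P biased coins.
%
% X is a TxN matrix of coin-flip results (1 = heads, 0 = tails): each of
%   the T rows is one trial of N flips, all made with the same unknown coin.
% theta is a 1xP vector of initial guesses for each coin's probability of
%   heads (0 < theta < 1).
% Each trial is assumed equally likely to come from any of the P coins
% (the mixing weights are fixed, not estimated). Returns the estimated
% heads probabilities.

% Convergence criterion (relative difference between successive estimates)
tol = 1e-6;

% Sizes and sufficient statistics: only the head count of each trial matters
T = size(X,1);      % number of trials
N = size(X,2);      % flips per trial
K = sum(X,2);       % Tx1 vector of head counts
P = length(theta);  % number of coins

% Header of the progress table (one column per coin)
fprintf('%10s', 'Step');
for j = 1:P
    fprintf(' %10s', sprintf('P(%d)', j));
end
fprintf(' %10s\n', 'RelDiff');

% Loop until convergence
relativeDifference = inf;
step = 1;
while relativeDifference > tol
    % Output the current estimates (fprintf reuses the format for each coin)
    fprintf('%10d', step);
    fprintf(' %9.4f%%', 100*theta);
    fprintf(' %10.6f\n', relativeDifference);
    oldTheta = theta;
    step = step + 1;

    % E-step: log-likelihood of each trial under each coin, then the
    % normalized weights W(t,j) = P(trial t came from coin j | data).
    % Subtracting each row's maximum before exponentiating avoids underflow
    % when N is large; the shift cancels in the normalization.
    L = zeros(T,P);
    for t = 1:T
        L(t,:) = ll(X(t,:),oldTheta);
    end
    W = exp(bsxfun(@minus, L, max(L,[],2)));
    W = bsxfun(@rdivide, W, sum(W,2));

    % Expected numbers of heads and tails attributed to each coin (1xP)
    Hd = sum(bsxfun(@times, W, K));
    Tl = sum(bsxfun(@times, W, N-K));

    % M-step: maximum-likelihood update (analytically soluble in this case)
    theta = Hd ./ (Hd + Tl);

    % Relative change between successive estimates, for the convergence test
    relativeDifference = norm(theta./oldTheta - 1);
end
end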
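function l = ll(x,p)
% Log-likelihood of one trial x (row of 0/1 flips) under each heads
% probability in p; returns a vector the same size as p.
n = length(x);
k = sum(x);
l = lognchoosek(n,k) + k * log(p) + (n-k) * log(1-p);
end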
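function p = lognchoosek(n,k)
% Log of the binomial coefficient n-choose-k, computed with log-gamma
% so that large n does not overflow
p = gammaln(n+1) - gammaln(k+1) - gammaln(n-k+1);
end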
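The two steps inside the `while` loop correspond to the updates below. Here w_{tj} is entry (t,j) of `W`, and Hd_j and Tl_j are the expected head and tail counts that the code stores in `Hd` and `Tl`. The derivation sketch assumes the equal-probability choice of coin implied by the code:

```latex
w_{tj} = \frac{\theta_j^{K_t}(1-\theta_j)^{N-K_t}}{\sum_{i=1}^{P} \theta_i^{K_t}(1-\theta_i)^{N-K_t}}
\qquad \text{(E-step)}

\mathrm{Hd}_j = \sum_{t=1}^{T} w_{tj} K_t,
\qquad
\mathrm{Tl}_j = \sum_{t=1}^{T} w_{tj} (N-K_t)

\theta_j \leftarrow \arg\max_{\theta_j} \bigl[\mathrm{Hd}_j \log\theta_j + \mathrm{Tl}_j \log(1-\theta_j)\bigr]
= \frac{\mathrm{Hd}_j}{\mathrm{Hd}_j + \mathrm{Tl}_j}
\qquad \text{(M-step)}
```

The binomial coefficient and the 1/P prior are the same for every coin, so they cancel when each row of W is normalized. Setting the derivative Hd_j/theta_j - Tl_j/(1 - theta_j) to zero gives the closed form assigned to `theta`.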