Skip to content

Instantly share code, notes, and snippets.

@mattjj
Last active October 10, 2017 00:04
Show Gist options
  • Save mattjj/f2d4776650d7a0988ab51886dee9cee4 to your computer and use it in GitHub Desktop.
Save mattjj/f2d4776650d7a0988ab51886dee9cee4 to your computer and use it in GitHub Desktop.
function [alpha, k1, k2] = em(data, alpha, k1, k2, tol, max_iter)
% EM  Fit a two-component exponential mixture by expectation-maximization.
%   [alpha, k1, k2] = em(data, alpha, k1, k2) refines the initial mixing
%   weight alpha (weight of component 1) and rates k1, k2 of the model
%       p(x) = alpha * k1 * exp(-k1 * x) + (1 - alpha) * k2 * exp(-k2 * x)
%   and stops once the parameter vector [alpha, k1, k2] moves by less than
%   tol (Euclidean norm) in one iteration. The latest estimates are returned.
%
%   em(..., tol) sets the convergence tolerance (default 1e-5).
%   em(..., tol, max_iter) caps the number of iterations (default 10000);
%   a warning 'em:maxIter' is issued if the cap is hit before convergence.
%
%   data may be a row or a column vector; it is used as a 1 x N row.
%   Raises 'em:nonfinite' if an update yields NaN/Inf (e.g. empty data, or
%   a component whose responsibilities are all zero).
if nargin < 5 || isempty(tol); tol = 1e-5; end
if nargin < 6 || isempty(max_iter); max_iter = 10000; end

x = data(:).';                           % force a 1 x N row for the outer products below
div = @(a, b) bsxfun(@rdivide, a, b);
add = @(a, b) bsxfun(@plus, a, b);

converged = false;
for iter = 1:max_iter
    % E step, in the log domain: log(weight) + log(rate) - rate * x.
    % Evaluating exp(-k*x) directly underflows to 0 for BOTH components when
    % k*x is large, giving 0/0 = NaN responsibilities and a loop that never ends.
    k = [k1; k2];
    logp = add(log([alpha; 1 - alpha]) + log(k), -k * x);   % 2 x N
    logp = add(logp, -max(logp, [], 1));                    % log-sum-exp shift per column
    R = exp(logp);
    R = div(R, sum(R, 1));                                  % responsibilities, columns sum to 1

    % M step: weighted maximum-likelihood updates.
    Nk = sum(R, 2);                      % effective number of points per component
    alpha_next = Nk(1) / sum(Nk);
    k_next = Nk ./ (R * x.');            % rate MLE: sum(r) / sum(r .* x)
    theta_next = [alpha_next, k_next.'];
    if any(~isfinite(theta_next))
        error('em:nonfinite', ...
            'EM produced non-finite parameters (empty data or a collapsed component).');
    end

    % Convergence check on the parameter step, then accept the update.
    step = norm([alpha, k1, k2] - theta_next);
    alpha = alpha_next;
    k1 = k_next(1);
    k2 = k_next(2);
    if step < tol
        converged = true;
        break
    end
end
if ~converged
    warning('em:maxIter', 'EM did not converge within %d iterations.', max_iter);
end
end
function varargout = multiassign(x)
% MULTIASSIGN  Spread the elements of an array across several outputs.
%   [a, b, c] = multiassign([1, 2, 3]) sets a = 1, b = 2 and c = 3,
%   taking elements of x in linear-index order.
count = numel(x);
varargout = cell(1, count);
for idx = 1:count
    varargout{idx} = x(idx);
end
end
function data = synth_data(alpha, k1, k2, N)
% SYNTH_DATA  Draw N samples from a two-component exponential mixture.
%   data = synth_data(alpha, k1, k2, N) returns a 1 x N row vector. Each
%   sample is drawn from an exponential with rate k1 (density
%   k1 * exp(-k1 * x)) with probability alpha, and from one with rate k2
%   otherwise.
%
%   Samples use inverse-CDF sampling, x = -log(u) / k with u ~ U(0,1), so
%   only base MATLAB/Octave is needed (no Statistics Toolbox exprnd, whose
%   argument is the mean 1/k rather than the rate k). rand draws from the
%   open interval (0,1), so log(u) is always finite.
rates = [k1, k2];
labels = 1 + (rand(1, N) > alpha);        % 1 with probability alpha, otherwise 2
data = -log(rand(1, N)) ./ rates(labels);
end
% Demo: recover the parameters of a two-component exponential mixture with EM.
rng(0);  % fixed seed so the synthetic data set is reproducible

% Ground truth: weight alpha on rate k1, weight 1 - alpha on rate k2.
alpha = 0.75; k1 = 2; k2 = 10;
N = 1000;                                % number of samples to draw
data = synth_data(alpha, k1, k2, N);

% Starting point for the second run, deliberately away from the truth.
alpha_guess = 0.5; k1_guess = 3; k2_guess = 5;

% Run 1: EM started at the true parameters (a sanity check).
[alpha_trueinit, k1_trueinit, k2_trueinit] = em(data, alpha, k1, k2);

% Run 2: EM started from the guess above.
[alpha_hat, k1_hat, k2_hat] = em(data, alpha_guess, k1_guess, k2_guess);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment