-
-
Save mattjj/f2d4776650d7a0988ab51886dee9cee4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [alpha, k1, k2] = em(data, alpha, k1, k2, tol, maxiter)
%EM Fit a two-component exponential mixture model by expectation-maximization.
%   [ALPHA, K1, K2] = EM(DATA, ALPHA, K1, K2) fits the mixture density
%
%       p(x) = ALPHA * K1 * exp(-K1 * x) + (1 - ALPHA) * K2 * exp(-K2 * x)
%
%   to the samples in DATA, starting from the given ALPHA, K1 and K2, and
%   returns the final estimates. K1 and K2 are expected to be positive
%   rates and ALPHA a mixing weight in [0, 1].
%
%   [...] = EM(..., TOL) stops once the Euclidean norm of the change in
%   [ALPHA, K1, K2] between successive iterations is below TOL
%   (default 1e-5; [] also selects the default).
%
%   [...] = EM(..., TOL, MAXITER) performs at most MAXITER iterations
%   (default Inf) and issues warning 'em:maxIter' if that limit is reached
%   before convergence.
%
%   DATA may be a row or a column vector; it is used as a row internally.
%   An empty DATA raises error 'em:emptyData'. If an update produces
%   non-finite parameters, warning 'em:nonFinite' is issued and the last
%   finite estimates are returned instead of looping forever.
if nargin < 5 || isempty(tol); tol = 1e-5; end
if nargin < 6 || isempty(maxiter); maxiter = Inf; end

x = data(:)';  % the M step takes inner products with a row vector
N = numel(x);
if N == 0
    error('em:emptyData', 'DATA must contain at least one sample.');
end

iter = 0;
while iter < maxiter
    iter = iter + 1;

    % E step, in the log domain: row j holds log(w_j * k_j * exp(-k_j * x))
    % with mixing weights w = [alpha, 1 - alpha].
    logp = [log(alpha)     + log(k1) - k1 * x;
            log(1 - alpha) + log(k2) - k2 * x];
    % Shift each column by its max before exponentiating so the larger term
    % is exp(0) = 1; otherwise large k*x underflows both densities to 0 and
    % the normalized responsibilities become 0/0 = NaN.
    R = exp(bsxfun(@minus, logp, max(logp, [], 1)));
    R = bsxfun(@rdivide, R, sum(R, 1));

    % M step: closed-form maximizers given the responsibilities R.
    r1 = sum(R(1, :));
    r2 = sum(R(2, :));
    alpha_next = r1 / (r1 + r2);
    k1_next = r1 / (R(1, :) * x');
    k2_next = r2 / (R(2, :) * x');

    params_next = [alpha_next, k1_next, k2_next];
    % A NaN step would never satisfy "step < tol", so bail out explicitly.
    if ~all(isfinite(params_next))
        warning('em:nonFinite', ...
            'EM produced non-finite parameters at iteration %d; returning the previous estimates.', iter);
        return
    end

    step = norm([alpha, k1, k2] - params_next);
    [alpha, k1, k2] = deal(alpha_next, k1_next, k2_next);
    if step < tol
        return
    end
end
warning('em:maxIter', 'EM did not converge within %d iterations.', maxiter);
end
function varargout = multiassign(x)
%MULTIASSIGN Spread the elements of an array across separate outputs.
%   [A, B, C] = MULTIASSIGN([a, b, c]) returns A = a, B = b and C = c,
%   taking elements of X in linear-index order.
count = numel(x);
varargout = cell(1, count);
for idx = 1:count
    varargout{idx} = x(idx);
end
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function data = synth_data(alpha, k1, k2, N)
%SYNTH_DATA Draw N samples from a two-component exponential mixture.
%   DATA = SYNTH_DATA(ALPHA, K1, K2, N) returns a 1-by-N row of samples.
%   Each one comes from the rate-K1 exponential with probability ALPHA and
%   from the rate-K2 exponential otherwise.
rates = [k1, k2];
% start everyone in component 1, then move those whose uniform draw
% exceeds alpha (probability 1 - alpha) to component 2
component = ones(1, N);
component(rand(1, N) > alpha) = 2;
% exprnd is parameterized by the mean mu, which is 1/rate
data = exprnd(1 ./ rates(component));
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Demo: fit a two-component exponential mixture with EM on synthetic data.
rng(0);  % fixed seed so the synthetic data, and hence the fits, are reproducible

% ground-truth mixture parameters and sample size
[alpha, k1, k2] = deal(0.75, 2, 10);
N = 1000;
data = synth_data(alpha, k1, k2, N);

% starting point for the second run
[alpha_guess, k1_guess, k2_guess] = deal(0.5, 3, 5);

% run 1: EM started at the true parameters
[alpha_trueinit, k1_trueinit, k2_trueinit] = em(data, alpha, k1, k2);

% run 2: EM started from the guess
[alpha_hat, k1_hat, k2_hat] = em(data, alpha_guess, k1_guess, k2_guess);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.