-
-
Save xiaohan2012/9133810 to your computer and use it in GitHub Desktop.
Expectation Maximization algorithm for Gaussian Mixture Model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%% ----------------------------------------------------------- | |
%% The EM part! | |
%% ----------------------------------------------------------- | |
function [means, stds, P] = em (X, K)
  % EM  Fit a K-component isotropic Gaussian mixture to X with the
  %     Expectation-Maximization algorithm.
  %
  % Inputs:
  %   X  N x featureNumber matrix, one data point per row
  %   K  number of mixture components
  %
  % Outputs:
  %   means  K x featureNumber matrix, one component mean per row
  %   stds   K x 1 vector, a single standard deviation per component
  %          (shared by all dimensions of that component)
  %   P      1 x K vector of mixing weights
  %
  % The parameters are initialised with rand, so results vary between runs
  % unless the caller seeds the random generator.
  %
  % NOTE(review): relies on the helper npdf (X, mean, std), which is not
  % defined in this file; it is assumed to return an N x 1 column of density
  % values, one per row of X -- confirm against its definition.

  [N, featureNumber] = size (X);

  % randomly initialize the parameters
  means = rand (K, featureNumber);
  stds = rand (K, 1);
  temp = rand (1, K);
  P = temp / sum (temp);  % normalize so the weights sum to 1

  convergent = 0;         % indicator variable, whether convergent or not

  p = zeros (N, K);       % p(n, k): probability that record n came from component k
  g = zeros (N, K);       % g(n, k): density of record n under component k
  newMeans = zeros (K, featureNumber);
  newStds = zeros (K, 1);
  newP = zeros (1, K);

  while ~convergent
    %% ---------------------------
    %% E-step: expected membership probabilities
    %% ---------------------------
    for k = 1:K
      g (:, k) = npdf (X, means (k, :), stds (k));
      p (:, k) = P (k) * g (:, k);
    end

    % Normalize each row. A row whose densities all underflowed to zero
    % would give 0/0 = NaN and the loop could never converge, so such
    % rows fall back to uniform membership instead.
    rowTotals = sum (p, 2);
    degenerate = (rowTotals == 0);
    p (degenerate, :) = 1 / K;
    rowTotals (degenerate) = 1;
    p = p ./ repmat (rowTotals, 1, K);

    %% ---------------------------
    %% M-step: maximize the likelihood function
    %% ---------------------------
    for k = 1:K
      weightSum = sum (p (:, k));
      newMeans (k, :) = sum (repmat (p (:, k), 1, featureNumber) .* X, 1) / weightSum;
      % The spread is measured around the UPDATED mean, as the M-step requires.
      sqDist = sum ((X - repmat (newMeans (k, :), N, 1)) .^ 2, 2);
      newStds (k) = sqrt (sum (p (:, k) .* sqDist) / (featureNumber * weightSum));
    end
    newP = sum (p, 1) / N;

    % Largest absolute change of each parameter group, reduced to a scalar
    % (the (:) keeps the K x featureNumber means matrix from yielding a row vector).
    meansDiff = max (abs (newMeans (:) - means (:)));
    stdsDiff = max (abs (newStds - stds));
    PDiff = max (abs (newP - P));

    % Converged only when ALL parameter groups have stopped moving.
    if meansDiff < 0.001 && stdsDiff < .0001 && PDiff < .0001
      convergent = 1;
    end

    means = newMeans;
    stds = newStds;
    P = newP;
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function samples = gendata (means, stds, P, N)
  % GENDATA  Draw N samples from a mixture of isotropic Gaussians.
  %
  % Inputs:
  %   means  componentNumber x featureNumber matrix, one mean per row
  %   stds   componentNumber x 1 vector, one standard deviation per component
  %   P      1 x componentNumber vector of mixing weights (not cumulative;
  %          they are accumulated here)
  %   N      number of samples to draw
  %
  % Output:
  %   samples  N x featureNumber matrix, one sample per row
  %
  % Assumes no correlation between the dimensions of a component.

  featureNumber = size (means, 2);
  componentNumber = length (P);

  %% pick a component for each of the N samples according to distribution P
  whichModel = zeros (1, N);
  randNums = rand (1, N);
  lower = 0;
  for i = 1:componentNumber
    if i == componentNumber
      % The last bin is open-ended so that rounding error in the running
      % sum of P can never leave a draw without a component.
      upper = Inf;
    else
      upper = lower + P (i);
    end
    whichModel (randNums >= lower & randNums < upper) = i;
    lower = upper;
  end

  modelMeans = zeros (N, featureNumber);  % per-sample mean of the chosen component
  modelStds = zeros (N, 1);               % per-sample std of the chosen component
  for i = 1:componentNumber
    selector = find (whichModel == i);
    modelMeans (selector, :) = repmat (means (i, :), length (selector), 1);
    modelStds (selector, :) = repmat (stds (i, :), length (selector), 1);
  end

  %% draw the data points using the z * std + mean method; the std column is
  %% replicated across however many features the means have
  samples = randn (N, featureNumber) .* repmat (modelStds, 1, featureNumber) + modelMeans;

  % plot the scatter plot
  %% styles = ['s', 'r', 'x'];
  %% for i = 1:componentNumber
  %%   scatter (samples (whichModel == i, 1), samples (whichModel == i, 2), styles (i));
  %%   hold on;
  %% end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
% Demo: sample N points from a known three-component mixture, then try to
% recover its parameters with EM.
means = [1, 1; 2, 2; 3, 3];  % true component means, one row per component
stds = [0.1; 0.1; 0.1];      % true per-component standard deviations
P = [0.3, 0.5, 0.2];         % mixing weights (not cumulative; gendata accumulates them)
N = 1000;
X = gendata (means, stds, P, N);
K = 3;  % three components
% Let's EM! (no trailing semicolon, so the estimates are printed)
[emeans, estds, eP] = em (X, K)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment