Last active
October 3, 2019 22:18
-
-
Save korymath/86ac78e712d3c49b791dc81da31ab1c2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [current_error] = average_data_distance_error(n_groups, memberships, distances)
    % Average, over groups 1..n_groups, of the mean pairwise distance inside each group.
    %
    % Inputs:
    %   n_groups    - number of groups; labels 1..n_groups are scored
    %   memberships - vector of group labels, one per data point
    %   distances   - square matrix, distances(p, q) = distance between
    %                 points p and q (e.g. squareform(pdist(data)))
    %
    % Output:
    %   current_error - (sum over groups of the within-group mean distance) / n_groups.
    %                   Every ordered pair (p, q) of members counts, including
    %                   p == q, so a singleton group scores distances(p, p).
    %                   An empty group contributes NaN (mean of an empty set).
    sum_error = 0;
    for i_group = 1:n_groups
        % indices of datapoints belonging to group i_group
        members = find(memberships == i_group);
        % distances(members, members) holds every ordered member pair in
        % column-major order, i.e. exactly the pairs an ndgrid over
        % {members, members} would enumerate, without an explicit loop.
        pair_block = distances(members, members);
        sum_error = sum_error + mean(pair_block(:));
    end
    current_error = sum_error / n_groups;
end
function [group_labels] = cluster_equal_groups(data, n_groups, n_members, verbose)
    % Equal-size clustering by pairwise membership exchanges.
    %
    % Inputs:
    %   data      - n-by-d matrix, one row per data point
    %   n_groups  - number of groups to form
    %   n_members - (optional) maximum allowed group size; defaults to
    %               ceil(n / n_groups) and must be at least that value
    %   verbose   - (optional, default false) print the initial labels and
    %               every tested exchange with its error
    %
    % Output:
    %   group_labels - n-by-1 vector of labels 1..n_groups matching the rows
    %                  of data; group sizes differ by at most one.
    %
    % Exchanging the labels of two points never changes group sizes, so the
    % sizes are fixed entirely by the initial assignment. The kmeans labels
    % are only used to order the points before cutting them into near-equal
    % chunks. Starting from raw kmeans labels would keep kmeans' unequal sizes.
    n_samples = size(data, 1);
    if nargin < 3 || isempty(n_members)
        n_members = ceil(n_samples / n_groups);
    end
    if nargin < 4
        verbose = false;
    end
    if n_samples < n_groups
        error("cluster_equal_groups: need at least n_groups (%d) data points, got %d", ...
              n_groups, n_samples);
    end
    if ceil(n_samples / n_groups) > n_members
        error("cluster_equal_groups: %d points do not fit into %d groups of at most %d members", ...
              n_samples, n_groups, n_members);
    end

    distances = squareform(pdist(data));

    % Order points by kmeans label, then cut that order into n_groups
    % contiguous chunks whose sizes differ by at most one (none is empty).
    [~, order] = sort(kmeans(data, n_groups));
    memberships = zeros(n_samples, 1);
    memberships(order) = floor((0:n_samples - 1)' * n_groups / n_samples) + 1;
    if verbose
        display(memberships');
    end

    current_err = average_data_distance_error(n_groups, memberships, distances);
    i_try = 1;
    while true
        past_err = current_err;
        for a = 1:n_samples
            for b = 1:a - 1
                % exchanging two members of the same group is a no-op
                if memberships(a) == memberships(b)
                    continue;
                end
                % exchange membership and check new error
                [memberships(a), memberships(b)] = deal(memberships(b), memberships(a));
                test_err = average_data_distance_error(n_groups, memberships, distances);
                if verbose
                    printf("%d: %d<->%d E=%g\n", i_try, a, b, test_err);
                end
                if test_err < current_err
                    current_err = test_err;   % keep the exchange
                else
                    % put them back
                    [memberships(a), memberships(b)] = deal(memberships(b), memberships(a));
                end
            end
        end
        % Stop after a full sweep with no accepted exchange. current_err only
        % changes when it strictly decreases, and there are finitely many
        % labelings, so this loop terminates.
        if past_err == current_err
            break;
        end
        i_try = i_try + 1;
    end
    group_labels = memberships;
end
% Example input: 32 two-dimensional points, to be split into 4 groups of 8.
X = [[ 3.65717783 -2.35242688];
     [-8.76958635 8.91749004];
     [-0.28553716 2.63076825];
     [2.94257145 6.86308015];
     [-0.34432412 3.31421283];
     [ 4.35608493 -2.53772814];
     [5.20830874 6.57070778];
     [6.64292785 1.95343147];
     [6.04035664 1.35191952];
     [-0.13386303 3.53459815];
     [6.23856053 2.48924218];
     [ 2.92326089 -6.96343091];
     [5.39424751 0.16278302];
     [ 2.71939864 -6.34252507];
     [6.3648091 2.19435851];
     [5.89923934 1.82693813];
     [4.71247534 6.87301748];
     [-0.88150049 3.27236785];
     [ 4.00076041 -2.0599616 ];
     [ 4.64779685 -0.0072023 ];
     [ 3.93597157 -2.07219354];
     [7.44281667 5.00732413];
     [-7.81470103 8.98804792];
     [-7.91476811 8.48525329];
     [ 2.58578377 -7.90613058];
     [6.99105527 5.16860678];
     [4.77338784 1.42694333];
     [ 2.64651479 -7.06087589];
     [4.42166999 6.67668983];
     [7.15465099 5.50019161];
     [-8.14593159 9.46525138];
     [7.26294783 5.48510317]];
% Arguments are positional, in signature order (data, n_groups, n_members).
% This function does not take name=value keyword arguments; writing
% `n_members=8, n_groups=4` would not bind by name.
labels = cluster_equal_groups(X, 4, 8)
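% TODO: the snippet below only prints the error of a kmeans clustering; turn it into a unit test of average_data_distance_error that asserts a known expected value (e.g. with fixed memberships instead of kmeans).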
%% unit testing %% | |
% X = [[-0.12260639 3.46400662]; [-8.17848874 9.18596168]; [-7.5051986 9.00890067]; [5.8667249 0.51736684]; [5.59627747 0.83136308]; [-1.01041072 3.34506264]]; | |
% distances = squareform(pdist(X)); | |
% n_groups = 3 | |
% memberships = kmeans(X, n_groups); | |
% average_data_distance_error(n_groups, memberships, distances) | |
%% unit testing %% | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment