Skip to content

Instantly share code, notes, and snippets.

@korymath
Last active October 3, 2019 22:18
Show Gist options
  • Save korymath/86ac78e712d3c49b791dc81da31ab1c2 to your computer and use it in GitHub Desktop.
Save korymath/86ac78e712d3c49b791dc81da31ab1c2 to your computer and use it in GitHub Desktop.
function [current_error] = average_data_distance_error(n_groups, memberships, distances)
% AVERAGE_DATA_DISTANCE_ERROR  Mean within-group pairwise distance, averaged over groups.
%   current_error = average_data_distance_error(n_groups, memberships, distances)
%
%   n_groups    - number of groups; labels 1..n_groups are examined.
%   memberships - vector of group labels, one per sample.
%   distances   - n-by-n matrix where distances(j, k) is the distance between
%                 samples j and k (e.g. squareform(pdist(data))).
%
%   For each group, this takes the mean of distances(idx, idx) over all ordered
%   pairs of members, diagonal included. It returns the sum of those per-group
%   means divided by n_groups.
%   A group with no members contributes 0 rather than NaN, so the result can
%   always be compared with "<" and "==" by the caller.
  sum_error = 0;
  for i_group = 1:n_groups
    % indices of datapoints belonging to group i_group
    idx = find(memberships == i_group);
    if isempty(idx)
      % mean([]) is NaN; skip so one empty group cannot poison the total.
      continue;
    end
    % The submatrix holds every ordered member pair. Read column-major, it
    % visits pairs in the same order as the former ndgrid(idx, idx) expansion.
    block = distances(idx, idx);
    sum_error = sum_error + mean(block(:));
  end
  current_error = sum_error / n_groups;
end
function [group_labels] = cluster_equal_groups(data, n_groups, n_members, verbose)
% CLUSTER_EQUAL_GROUPS  Equal-size clustering via pairwise membership exchanges.
%   group_labels = cluster_equal_groups(data, n_groups, n_members, verbose)
%
%   data      - n-by-d matrix, one sample per row.
%   n_groups  - number of groups to form.
%   n_members - target group size (optional; defaults to ceil(n / n_groups)).
%               Exchanges never change group sizes, so the sizes fixed by the
%               initial assignment are the sizes of the result. If
%               n_groups * n_members ~= n, a warning is issued: any overflow
%               goes to the last group, and a shortfall leaves trailing
%               groups short or empty.
%   verbose   - optional logical (default false); print progress when true.
%
%   group_labels - n-by-1 vector of labels in 1..n_groups, aligned with the
%                  rows of data.
%
%   Uses pdist/squareform/kmeans (statistics package) and
%   average_data_distance_error.
  n_samples = size(data, 1);
  if nargin < 3 || isempty(n_members)
    n_members = ceil(n_samples / n_groups);
  end
  if nargin < 4 || isempty(verbose)
    verbose = false;
  end
  if n_groups * n_members ~= n_samples
    warning('cluster_equal_groups: n_groups * n_members (%d) ~= number of samples (%d); group sizes will be unequal.', ...
            n_groups * n_members, n_samples);
  end

  distances = squareform(pdist(data));

  % Equal-size seed: order samples by k-means label so that members of the
  % same k-means cluster stay adjacent, then cut that order into consecutive
  % chunks of n_members. Raw k-means labels cannot be used directly because
  % their group sizes are arbitrary and exchanges would preserve them.
  seed = kmeans(data, n_groups);
  [~, order] = sort(seed);
  memberships = zeros(n_samples, 1);
  memberships(order) = min(ceil((1:n_samples)' / n_members), n_groups);
  if verbose
    disp(memberships');
  end

  current_err = average_data_distance_error(n_groups, memberships, distances);
  i_try = 1;
  while true
    past_err = current_err;
    for a = 1:n_samples
      for b = 1:a-1
        % Exchanging two members of the same group is a no-op; skip it.
        if memberships(a) == memberships(b)
          continue;
        end
        % exchange membership and check new error
        [memberships(a), memberships(b)] = deal(memberships(b), memberships(a));
        test_err = average_data_distance_error(n_groups, memberships, distances);
        if verbose
          fprintf('%d: %d<->%d E=%g\n', i_try, a, b, test_err);
        end
        if test_err < current_err
          current_err = test_err;
        else
          % put them back
          [memberships(a), memberships(b)] = deal(memberships(b), memberships(a));
        end
      end
    end
    % Stop when a full pass brings no strict improvement. Testing "not
    % smaller" instead of "equal" guarantees termination even if the error is NaN.
    if ~(current_err < past_err)
      break;
    end
    i_try = i_try + 1;
  end
  group_labels = memberships;
end
% Example data: 32 two-dimensional samples.
X = [[ 3.65717783 -2.35242688];
[-8.76958635 8.91749004];
[-0.28553716 2.63076825];
[2.94257145 6.86308015];
[-0.34432412 3.31421283];
[ 4.35608493 -2.53772814];
[5.20830874 6.57070778];
[6.64292785 1.95343147];
[6.04035664 1.35191952];
[-0.13386303 3.53459815];
[6.23856053 2.48924218];
[ 2.92326089 -6.96343091];
[5.39424751 0.16278302];
[ 2.71939864 -6.34252507];
[6.3648091 2.19435851];
[5.89923934 1.82693813];
[4.71247534 6.87301748];
[-0.88150049 3.27236785];
[ 4.00076041 -2.0599616 ];
[ 4.64779685 -0.0072023 ];
[ 3.93597157 -2.07219354];
[7.44281667 5.00732413];
[-7.81470103 8.98804792];
[-7.91476811 8.48525329];
[ 2.58578377 -7.90613058];
[6.99105527 5.16860678];
[4.77338784 1.42694333];
[ 2.64651479 -7.06087589];
[4.42166999 6.67668983];
[7.15465099 5.50019161];
[-8.14593159 9.46525138];
[7.26294783 5.48510317]];
% Octave/MATLAB do not bind arguments by name the way Python keyword arguments
% do, so a call written as cluster_equal_groups(data=X, n_members=8, n_groups=4)
% does not mean what it says. Pass the arguments positionally, in the declared
% order (data, n_groups, n_members): 4 groups of 8 members = 32 samples.
labels = cluster_equal_groups(X, 4, 8)
% It is helpful to have a unit test of the error function with a known expected value.
%% unit test for average_data_distance_error %%
% Six points forming three well-separated pairs. If kmeans recovers those pairs,
% each group's mean is half its pair distance, so the result is
% (d1 + d2 + d3) / 6, about 0.334.
% X = [[-0.12260639 3.46400662]; [-8.17848874 9.18596168]; [-7.5051986 9.00890067]; [5.8667249 0.51736684]; [5.59627747 0.83136308]; [-1.01041072 3.34506264]];
% distances = squareform(pdist(X));
% n_groups = 3
% memberships = kmeans(X, n_groups);
% average_data_distance_error(n_groups, memberships, distances)
%% unit testing %%
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment