Created
October 21, 2021 19:15
-
-
Save Eleobert/6e3dd16f64f63cd927aa7c13da238922 to your computer and use it in GitHub Desktop.
Agglomerative clustering
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <armadillo> | |
#include <vector> | |
// Similarity between two clusters, measured through a shared exemplar.
//
// The members of both clusters are pooled, and the pooled member whose column
// of the pooled similarity sub-matrix has the largest mean is chosen as the
// exemplar. The result is the average of (a) the mean similarity between the
// exemplar and the members of cluster_a and (b) the mean similarity between
// the exemplar and the members of cluster_b.
//
// sim        pairwise similarity matrix, indexed by the values stored in the clusters
// cluster_a  indices (into sim) of the first cluster's members
// cluster_b  indices (into sim) of the second cluster's members
auto get_cluster_sim(const arma::mat& sim, const arma::uvec& cluster_a, const arma::uvec& cluster_b)
{
    const arma::uvec members = arma::join_cols(cluster_a, cluster_b);
    const arma::mat member_sims = sim.submat(members, members);

    // Column means of the pooled sub-matrix; the best column identifies the exemplar.
    const arma::rowvec column_means = arma::mean(member_sims);
    const auto exemplar = members(column_means.index_max());

    const arma::vec exemplar_sims = sim.col(exemplar);
    const double mean_to_a = arma::mean(exemplar_sims.elem(cluster_a));
    const double mean_to_b = arma::mean(exemplar_sims.elem(cluster_b));
    return (mean_to_a + mean_to_b) / 2.0;
}
// Builds the symmetric cluster-to-cluster similarity matrix.
//
// Entry (i, j) with i != j holds get_cluster_sim(sim, clusters[i], clusters[j]).
// Diagonal entries are set to -infinity so that a cluster is never selected
// as its own best match by a maximum search.
auto get_clusters_sim(const arma::mat& sim, const std::vector<arma::uvec>& clusters)
{
    const size_t n_clusters = clusters.size();

    arma::mat pairwise(n_clusters, n_clusters);
    pairwise.fill(arma::datum::nan);

    for(size_t col = 0; col < n_clusters; ++col)
    {
        pairwise(col, col) = -arma::datum::inf;

        // Compute each unordered pair once and mirror it across the diagonal.
        for(size_t row = 0; row < col; ++row)
        {
            const double pair_sim = get_cluster_sim(sim, clusters[row], clusters[col]);
            pairwise(row, col) = pair_sim;
            pairwise(col, row) = pair_sim;
        }
    }
    return pairwise;
}
// Returns the (row, column) position of the largest element of mat as a pair.
//
// as_col() flattens the matrix by concatenating its columns, so a linear
// index k corresponds to row k % n_rows and column k / n_rows. Dividing by
// n_cols instead is only correct for square matrices.
auto get_index_max(const arma::mat& mat)
{
    const auto linear_index = arma::index_max(mat.as_col());
    return std::make_pair(linear_index % mat.n_rows, linear_index / mat.n_rows);
}
auto remove_clusters(std::vector<arma::uvec>& clusters, size_t idx1, size_t idx2) | |
{ | |
auto [idx_min, idx_max] = std::minmax(idx1, idx2); | |
clusters.erase(clusters.begin() + idx_max); | |
clusters.erase(clusters.begin() + idx_min); | |
} | |
auto remove_cluster_similarities(arma::mat& clusters_sim, size_t idx1, size_t idx2) | |
{ | |
auto [idx_min, idx_max] = std::minmax(idx1, idx2); | |
clusters_sim.shed_col(idx_max); | |
clusters_sim.shed_col(idx_min); | |
clusters_sim.shed_row(idx_max); | |
clusters_sim.shed_row(idx_min); | |
} | |
// Grows the cluster similarity matrix by one row and one column for
// new_cluster and fills them in.
//
// The new row/column is appended at the end. Entry (i, last) and its mirror
// (last, i) are set to get_cluster_sim(sim, clusters[i], new_cluster) for every
// pre-existing row i, and the new diagonal entry is set to -infinity.
//
// NOTE(review): clusters[i] is paired with row i of clusters_sim, so the caller
// is presumably expected to keep both in the same order -- confirm at call sites.
auto update_clusters_sim(const arma::mat& sim, arma::mat& clusters_sim, std::vector<arma::uvec>& clusters,
                         const arma::uvec& new_cluster)
{
    // Append a NaN row at the bottom ...
    arma::rowvec nan_row(clusters_sim.n_cols);
    nan_row.fill(arma::datum::nan);
    clusters_sim = arma::join_cols(clusters_sim, nan_row);

    // ... then a NaN column on the right (one element longer than the old column count).
    arma::vec nan_col(clusters_sim.n_cols + 1);
    nan_col.fill(arma::datum::nan);
    clusters_sim = arma::join_rows(clusters_sim, nan_col);

    const auto last = clusters_sim.n_cols - 1;
    const auto n_existing = clusters_sim.n_rows - 1;
    for(size_t row = 0; row < n_existing; ++row)
    {
        const double pair_sim = get_cluster_sim(sim, clusters[row], new_cluster);
        clusters_sim(row, last) = pair_sim;
        clusters_sim(last, row) = pair_sim;
    }
    clusters_sim(last, last) = -arma::datum::inf;
}
// Agglomerative clustering driven by a similarity matrix.
//
// Starting from the given clusters, repeatedly merges the pair of clusters
// with the highest cluster-to-cluster similarity. Stops and returns the
// current clusters as soon as the best available similarity is below
// cut_height, or when only a single cluster remains.
//
// sim         pairwise similarity matrix, indexed by the values stored in the clusters
// clusters    initial clusters, each a vector of indices into sim (taken by value)
// cut_height  minimum similarity required for two clusters to be merged
//
// Returns the resulting list of clusters; a merged cluster is appended at the
// end of the list and its two parents are removed.
auto agcluster(const arma::mat& sim, std::vector<arma::uvec> clusters, float cut_height)
{
    // With fewer than two clusters there is nothing to merge. Returning early
    // also avoids searching for a maximum in an empty similarity matrix and
    // avoids "merging" a single cluster with itself.
    if(clusters.size() < 2)
    {
        return clusters;
    }

    arma::mat clusters_sim = get_clusters_sim(sim, clusters);
    while(true)
    {
        // Most similar pair of clusters; the diagonal is -inf so it is never
        // chosen while a finite off-diagonal entry exists.
        const auto [i_max, j_max] = get_index_max(clusters_sim);
        const auto height = clusters_sim(i_max, j_max);
        if(height < cut_height)
        {
            return clusters;
        }

        arma::uvec new_cluster = arma::join_cols(clusters[i_max], clusters[j_max]);

        // remove the old clusters and insert the new one
        remove_clusters(clusters, i_max, j_max);
        clusters.emplace_back(new_cluster);
        if(clusters.size() == 1)
        {
            return clusters;
        }

        remove_cluster_similarities(clusters_sim, i_max, j_max);

        // add the new cluster similarities
        update_clusters_sim(sim, clusters_sim, clusters, new_cluster);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment