Skip to content

Instantly share code, notes, and snippets.

@darcyabjones
Created April 2, 2025 15:29
Show Gist options
  • Save darcyabjones/2a030aaa149dd00823af8fe2d746e4e3 to your computer and use it in GitHub Desktop.
Save darcyabjones/2a030aaa149dd00823af8fe2d746e4e3 to your computer and use it in GitHub Desktop.
Tools to help perform consensus multi-resolution clustering as described in https://doi.org/10.1101/2022.10.09.511493
# Code to build a consensus clustering from several clusterings of the same
# cells (e.g. Leiden clusterings at different resolutions or seeds), following
# https://www.biorxiv.org/content/biorxiv/early/2022/10/11/2022.10.09.511493.full.pdf
# Basic usage:
#
#P1 <- factor(c(1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4))
#P2 <- factor(c(1, 1, 2, 1, 2, 1, 2, 2, 3, 3, 4, 4, 4, 4, 4, 4))
#P3 <- factor(c(1, 1, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 3, 4, 4))
#
#clustering_sets <- list("P1" = P1, "P2" = P2, "P3" = P3)
#R <- get_rcl_similarity(clustering_sets)
#Rtree <- cluster_rcl(R)
#heatmap(R, Rowv = as.dendrogram(Rtree), Colv = as.dendrogram(Rtree), scale = "none")
#
#tmp <- ssa(Rtree, c(6, 3, 1.5), min_size = 2, prefix = "C")
#tree <- tmp$tree
#clusterings <- tmp$clusterings
#rm(tmp)
#ape::plot.phylo(tree, show.node.label = TRUE, use.edge.length = FALSE)
# One-hot encode a clustering.
#
# A: cluster labels, one per cell. Should be a factor; anything else is
#   coerced with as.factor() after a warning, so the level (column) order is
#   whatever as.factor() produces.
# Returns a numeric 0/1 matrix with length(A) rows and one column per level
# of A (unused levels included), with the levels as column names. Cells whose
# label is NA get an all-zero row, so rows always line up with the cells.
onehot <- function(A) {
  if (!is.factor(A)) {
    warning("A is not a factor, you may get unexpected cluster ids")
    A <- as.factor(A)
  }
  codes <- as.integer(A)
  m <- matrix(
    0,
    nrow = length(A),
    ncol = nlevels(A),
    dimnames = list(NULL, levels(A))
  )
  # Set one cell per labelled row directly instead of using model.matrix():
  # with the default na.action, model.matrix() drops NA rows, which would
  # misalign cells with rows.
  labelled <- which(!is.na(codes))
  m[cbind(labelled, codes[labelled])] <- 1
  m
}
# Read each row of B as the digits of a number in `base`, with column 1 as
# the least significant digit, and return those numbers as a numeric vector
# (named by rownames(B), if any).
# The result is a double, so values above 2^53 are not exact and distinct
# rows can round to the same number; a warning is issued when that is
# possible.
bincount <- function(B, base = 2) {
  n_digits <- ncol(B)
  # seq_len() rather than seq(0, ncol(B) - 1): the latter gives c(0, -1)
  # when B has no columns.
  place_values <- base ^ (seq_len(n_digits) - 1)
  if (n_digits > 0 && base ^ n_digits > 2 ^ 53) {
    warning("bincount: values may exceed 2^53 and lose precision")
  }
  (B %*% place_values)[, 1]
}
# Greatest common subpartition (GCS) of two clusterings of the same cells.
#
# A, B: cluster labels, one per cell (factors; other types are coerced by
#   onehot() with a warning).
# Returns the factor produced by find_gcs_oh() on the two one-hot encodings:
# one GCS group label per cell, intended so that cells share a group when
# they share a cluster in A and also share a cluster in B.
find_gcs <- function(A, B) {
  find_gcs_oh(onehot(A), onehot(B))
}
# Label every cell by its combination of memberships in two one-hot encoded
# partitions. Cells with the same label form one group of the greatest
# common subpartition (GCS).
#
# A, B: 0/1 matrices with one row per cell (e.g. from onehot()).
# Returns a factor with one entry per row, where rows are equal exactly when
# their patterns in cbind(A, B) are equal.
# - If cbind(A, B) has at most 53 columns, the labels are the base-2 row
#   codes, which are exact in double precision.
# - Otherwise base-2 codes would round and merge distinct groups. In that
#   case the labels are the indices of the non-zero columns joined with "_".
find_gcs_oh <- function(A, B) {
  X <- cbind(A, B)
  if (ncol(X) <= 53) {
    codes <- unname((X %*% (2 ^ (seq_len(ncol(X)) - 1)))[, 1])
    return(as.factor(codes))
  }
  keys <- apply(X, MARGIN = 1, FUN = function(row) {
    paste(which(row != 0), collapse = "_")
  })
  as.factor(unname(keys))
}
# Partition discrepancy of clustering A with respect to a GCS labelling
# (e.g. from find_gcs()). Both are one label per cell; they are one-hot
# encoded and passed to partition_discrepancy_oh().
partition_discrepancy <- function(A, gcs) {
  a_sets <- onehot(A)
  gcs_sets <- onehot(gcs)
  partition_discrepancy_oh(a_sets, gcs_sets)
}
# Partition discrepancy computed on one-hot matrices.
#
# A: cells x clusters 0/1 matrix; gcs: cells x GCS-groups 0/1 matrix.
# For each cluster (column) of A, count its cells lying outside each GCS
# group, keep the smallest of those counts, and sum over all clusters of A.
partition_discrepancy_oh <- function(A, gcs) {
  outside_gcs <- 1 - gcs
  per_cluster <- vapply(
    seq_len(ncol(A)),
    FUN.VALUE = numeric(1),
    FUN = function(j) {
      min(colSums(A[, j] * outside_gcs))
    }
  )
  sum(per_cluster)
}
# Return the label that cells x and y share in the labelling P.
#
# P: cluster labels, one per cell. x, y: single cell indices.
# Returns P[x] when P[x] == P[y]. Otherwise it returns an empty vector of
# P's storage type; this also happens when either label is NA.
set_finding <- function(P, x, y) {
  # isTRUE() makes an NA label (e.g. an unclustered cell) mean "no shared
  # set" instead of making if() fail with "missing value where TRUE/FALSE
  # needed".
  if (isTRUE(P[x] == P[y])) {
    P[x]
  } else {
    vector(length = 0, mode = typeof(P))
  }
}
# One-hot counterpart of set_finding(): the column names of P (clusters)
# in which both row x and row y have a 1. Returns an empty result when the
# two cells share no cluster.
set_finding_oh <- function(P, x, y) {
  in_both <- which((P[x, ] * P[y, ]) == 1)
  colnames(P)[in_both]
}
# Consistency score of one GCS group between two partitions.
#
# A, B: cells x clusters one-hot matrices. gcs: 0/1 indicator vector over
#   cells marking one GCS group.
# Finds the cluster of A and the cluster of B holding the most group
# members. Returns the group size divided by the size of the smaller of
# those two clusters, or 0 if the group overlaps no cluster of A or no
# cluster of B.
scy <- function(A, B, gcs) {
  overlap_a <- t(A) %*% gcs
  overlap_b <- t(B) %*% gcs
  best_a <- which.max(overlap_a)
  best_b <- which.max(overlap_b)
  if (overlap_a[best_a] == 0 || overlap_b[best_b] == 0) {
    return(0)
  }
  smaller_cluster <- min(sum(A[, best_a]), sum(B[, best_b]))
  sum(gcs) / smaller_cluster
}
# Cell-by-cell co-clustering similarity across a set of clusterings.
#
# clustering_sets: a list, data.frame, or matrix (columns are used) of
#   clusterings of the same cells, one label per cell each. Each clustering
#   is one-hot encoded with onehot().
# For every pair of clusterings (i, j), the cells are split into GCS groups
# (same cluster in both). Every pair of cells in a group, including each
# cell paired with itself, gains that group's scy() consistency score.
# Returns an ncells x ncells numeric matrix: the accumulated scores divided
# by the number of clusterings.
# NOTE(review): scores are summed over n * (n - 1) / 2 pairs but divided by
# n, so entries can exceed 1 with four or more clusterings. cluster_rcl()
# rescales by max(R), so its tree is unaffected. Confirm the intended
# normalisation.
get_rcl_similarity <- function(clustering_sets) {
  if (is.matrix(clustering_sets)) {
    clustering_sets <- as.data.frame(clustering_sets)
  }
  onehots <- lapply(clustering_sets, onehot)
  n_sets <- length(onehots)
  ncells <- length(clustering_sets[[1]])
  R <- matrix(0, ncol = ncells, nrow = ncells)
  for (i in seq_len(n_sets)) {
    Pi <- onehots[[i]]
    for (j in seq_len(i - 1)) {
      Pj <- onehots[[j]]
      gcs <- onehot(find_gcs_oh(Pi, Pj))
      for (k in seq_len(ncol(gcs))) {
        gcsk <- gcs[, k]
        consistency <- scy(Pi, Pj, gcsk)
        xs <- which(as.logical(gcsk))
        # Update the whole group block at once instead of a scalar double
        # loop over every (x, y) pair of members.
        R[xs, xs] <- R[xs, xs] + consistency
      }
    }
  }
  R / n_sets
}
# Single-linkage hierarchical clustering of cells from a similarity matrix
# R (e.g. from get_rcl_similarity()).
#
# Similarities are scaled by max(R) and turned into distances
# 1 - R / max(R). Returns the hclust object.
# If R is all zero, the scaled values are NaN (0 / 0), and hclust() will
# presumably fail.
cluster_rcl <- function(R) {
  scaled <- R / max(R)
  hclust(as.dist(1 - scaled), method = "single")
}
# Annotate every node of a binary tree with its tip count and an lss value.
#
# tree: a "phylo" object, or anything ape::as.phylo() converts (e.g. an
#   hclust). Every internal node must have exactly two children.
# For an internal node v, lss[v] is the largest value, over all internal
# nodes u in the subtree rooted at v, of the tip count of u's smaller
# child. Tips have lss 0 and size 1.
# Returns a list with:
#   tree: reordered "pruningwise", with node.label set to the internal
#     node numbers.
#   lss: integer vector indexed by node number.
#   size: integer vector of tip counts indexed by node number.
lss <- function(tree) {
  # inherits() rather than class(tree) != "phylo": the latter is not a
  # single TRUE/FALSE when an object has more than one class.
  if (!inherits(tree, "phylo")) {
    tree <- ape::as.phylo(tree)
  }
  n_nodes <- max(tree$edge)
  lsss <- integer(length = n_nodes)
  sizes <- integer(length = n_nodes)
  tree$node.label <- length(tree$tip.label) + seq_len(tree$Nnode)
  tree <- ape::reorder.phylo(tree, order = "pruningwise")
  clade_list <- split(tree$edge[, 2], tree$edge[, 1])
  # Relies on pruningwise order listing each parent after its children;
  # the stopifnot() on tips below catches violations.
  clade_order <- unique(tree$edge[, 1])
  for (clade in clade_order) {
    members <- clade_list[[as.character(clade)]]
    stopifnot(length(members) == 2)
    for (member in members) {
      if (sizes[member] == 0) {
        # Only tips can still be unsized here.
        stopifnot(member <= length(tree$tip.label))
        sizes[member] <- 1
      }
    }
    minsize <- min(sizes[members])
    maxlss <- max(lsss[members])
    lsss[clade] <- max(maxlss, minsize)
    sizes[clade] <- sum(sizes[members])
  }
  list(tree = tree, lss = lsss, size = sizes)
}
# Replace arbitrary cluster ids with compact, readable ones.
#
# All clusterings share one id map. Every distinct id found in any element
# of `clusterings` is included; NA is excluded because table() drops it.
# Ids are ranked by decreasing cluster size (ties follow the order of the
# combined factor levels) and given new ids:
#   alpha = "lower"          -> letters a, b, ... (at most 26 ids)
#   alpha = "upper" or TRUE  -> LETTERS A, B, ... (at most 26 ids)
#   otherwise, pad = TRUE    -> integers zero-padded to a common width
#   otherwise                -> plain integers as strings
# Other alpha strings are not supported (the `||` test fails on them).
# A non-NULL, non-NA `prefix` is pasted in front of every new id.
#
# clusterings: a list of cluster-id vectors (e.g. one per resolution).
# tree: optional tree whose node.label values are old ids. Labels found in
#   the map are translated; all other labels become NA.
# Returns a list with:
#   id_map: data.frame(old_id, new_id, frequency).
#   clusterings: the translated list, with NA kept as NA.
#   tree: the relabelled tree, present only if a tree was given.
# NOTE(review): unique(cl) removes repeated (id, frequency) rows. An id
# that occurs with different sizes in different clusterings therefore keeps
# several rows, and lookups by name use only the first of them. That is
# fine for ssa(), where a clade always has the same size; check it for
# other callers.
renumber_clusters <- function(clusterings, tree = NULL, pad = TRUE, prefix = NULL, alpha = FALSE) {
# One (id, count) row per cluster of every clustering.
cl <- do.call(rbind, lapply(
unname(clusterings),
function(ti) {as.data.frame(table(ti))}
))
cl <- unique(cl)
# Largest clusters first; ti is still a factor here, so ties sort by level.
cl <- cl[order(-cl$Freq, cl$ti), ]
cl$ti <- as.character(levels(cl$ti)[cl$ti])
id_map <- seq_len(nrow(cl))
if (alpha == "lower") {
stopifnot(nrow(cl) <= length(letters))
id_map <- letters[id_map]
} else if ((alpha == "upper") || alpha) {
stopifnot(nrow(cl) <= length(LETTERS))
id_map <- LETTERS[id_map]
} else if (pad) {
# Width = number of digits in nrow(cl), so every id has the same length.
padn <- ceiling(log10(nrow(cl) + 1))
id_map <- sprintf(paste0("%0", padn, "d"), id_map)
} else {
id_map <- as.character(id_map)
}
if ((!is.null(prefix)) && (!is.na(prefix))) {
id_map <- paste0(prefix, id_map)
}
# Named list: old id (as a string) -> new id, looked up with [[ ]].
id_map <- as.list(id_map)
names(id_map) <- cl$ti
cl$new_id <- vapply(cl$ti, FUN.VALUE = "1", FUN = function(ti) {
if (is.na(ti)) {
return(as.character(NA))
} else {
return(id_map[[as.name(ti)]])
}
})
colnames(cl) <- c("old_id", "frequency", "new_id")
cl <- cl[, c("old_id", "new_id", "frequency")]
# Translate every clustering element-wise; NA stays NA.
new_clusterings <- lapply(clusterings, function(rli) {
vapply(rli, FUN.VALUE = "1", FUN = function(ti) {
if (is.na(ti)) {
return(as.character(NA))
} else {
return(id_map[[as.name(ti)]])
}
})
})
out <- list(id_map = cl, clusterings = new_clusterings)
if (!is.null(tree)) {
# Node labels that are cluster ids get the new id; all others become NA.
tree$node.label <- vapply(tree$node.label, FUN.VALUE = "1", function(ti) {
if (as.character(ti) %in% names(id_map)) {
return(id_map[[as.name(ti)]])
} else {
return(as.character(NA))
}
})
out$tree <- tree
}
return(out)
}
# Cut a hierarchical clustering tree into flat clusterings at several
# resolutions, using the per-node lss values from lss().
#
# tree: an hclust (e.g. from cluster_rcl()) or anything lss() accepts.
# resolutions: for each resolution, r = resolution / 2. Walking from the
#   root down, the highest clade (with at least min_size tips) whose two
#   children both have lss <= r becomes a cluster, and all of its tips get
#   that cluster. Tips that end up in no cluster are NA.
# min_size: defaults to max(1, min(resolutions)).
# renumber, pad, prefix, alpha: when renumber is TRUE, ids (node numbers)
#   are rewritten by renumber_clusters() with these options.
# Returns a list with:
#   clusterings: matrix with one row per tip and one column per resolution,
#     largest resolution first, named by resolution.
#   tree: the tree used, with node labels relabelled if renumbered.
ssa <- function(tree, resolutions, min_size = NULL, renumber = TRUE, pad = TRUE, prefix = NULL, alpha = FALSE) {
  resolutions <- sort(resolutions, decreasing = TRUE)
  if (is.null(min_size)) {
    min_size <- max(1, min(resolutions))
  }
  annotated <- lss(tree)
  lsss <- annotated$lss
  sizes <- annotated$size
  tree <- ape::reorder.phylo(annotated$tree, order = "postorder")
  n_tips <- length(tree$tip.label)
  n_nodes <- nrow(tree$edge) + 1
  # The tree is the same for every resolution, so the child lists and the
  # top-down visiting order (reversed postorder) are computed once.
  children_of <- split(tree$edge[, 2], tree$edge[, 1])
  top_down <- rev(unique(tree$edge[, 1]))
  clusterings <- list()
  for (resolution in resolutions) {
    half_res <- resolution / 2
    # -1 = excluded/unset, 0 = not a cluster, > 0 = cluster id (node number).
    assigned <- rep(-1, n_nodes)
    for (clade in top_down) {
      kids <- children_of[[as.character(clade)]]
      stopifnot(length(kids) == 2)
      if (assigned[clade] > 0) {
        assigned[kids] <- assigned[clade]
      } else if (sizes[clade] < min_size) {
        assigned[c(clade, kids)] <- -1
      } else if (all(lsss[kids] <= half_res)) {
        assigned[c(clade, kids)] <- clade
      } else {
        assigned[clade] <- 0
      }
    }
    assigned[assigned < 0] <- NA
    clusterings[[as.name(resolution)]] <- assigned[seq_len(n_tips)]
  }
  if (renumber) {
    relabelled <- renumber_clusters(clusterings, tree = tree, pad = pad, prefix = prefix, alpha = alpha)
    tree <- relabelled$tree
    clusterings <- relabelled$clusterings
  }
  list(clusterings = do.call(cbind, clusterings), tree = tree)
}
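# Alternative approach: consensus clustering of repeated Leiden runs on a graph
# (see consensus_clust(), after https://doi.org/10.1038/srep00336).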
# Pairwise co-assignment similarity between cells over repeated clusterings.
#
# X: a data.frame or matrix with one row per cell and one column per
#   clustering run (e.g. the output of recluster()).
# Returns an nrow(X) x nrow(X) matrix. Entry [i, j] is the fraction of the
# ncol(X) runs in which cells i and j have the same label. A run where
# either label is NA is left out of the numerator but still counts in the
# denominator.
simple_sim <- function(X) {
  # Transpose once: t() on a data.frame converts the whole table to a
  # matrix, and the original code redid that for every row.
  Xt <- t(X)
  n_runs <- ncol(X)
  apply(X, MARGIN = 1, FUN = function(xi) {
    # xi recycles down each column of Xt, i.e. it is compared with every cell.
    agree <- rowSums(t(xi == Xt), na.rm = TRUE)
    agree / n_runs
  })
}
# Run Leiden clustering nseeds times on graph G, each run with its own seed.
#
# resolution, objective_function, weights, beta, n_iterations: passed to
#   cluster_leiden().
# seed: if non-NULL, set.seed(seed) first so the per-run seeds are
#   reproducible. The function calls set.seed() for every run and so
#   changes the global RNG state.
# length(G) sets the number of rows (presumably the vertex count for an
# igraph graph; TODO confirm).
# Returns a data.frame with one membership column per run, named
# sprintf("R%f_C%d", resolution, run).
recluster <- function(G, resolution, objective_function = "CPM", weights = NULL, beta = 0.01, n_iterations = 2, nseeds = 50, seed = NULL) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  run_seeds <- sample(1:2^15, nseeds)
  memberships <- matrix(NA, nrow = length(G), ncol = nseeds)
  for (run in seq_len(nseeds)) {
    set.seed(run_seeds[[run]])
    partition <- cluster_leiden(
      G,
      resolution_parameter = resolution,
      objective_function = objective_function,
      weights = weights,
      beta = beta,
      n_iterations = n_iterations
    )
    memberships[, run] <- membership(partition)
  }
  colnames(memberships) <- sprintf("R%f_C%d", resolution, seq_len(nseeds))
  as.data.frame(memberships)
}
# Consensus clustering of graph G at a single resolution, after
# https://doi.org/10.1038/srep00336.
#
# nseeds: runs per round. seed: if non-NULL, set.seed(seed) before the
# first round. max_iter: how many extra rounds may follow the first.
# tau: co-assignment fractions below this are set to 0.
#
# Each round clusters the graph nseeds times (recluster()) and computes the
# fraction of runs in which each pair of cells is co-clustered
# (simple_sim()). Fractions below tau are set to 0. While that matrix holds
# any value other than 0 or 1, it becomes a weighted, undirected graph
# without self-loops, and the next round clusters that graph. After
# max_iter extra rounds a "failed to converge" warning is issued and the
# last round is kept.
# Returns a list with:
#   clusters: membership from the first run of the final round.
#   adjacency: the final thresholded co-assignment matrix.
consensus_clust <- function(G, resolution, nseeds = 50, seed = NULL, max_iter = 5, tau = 0.2) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  run_round <- function(graph) {
    runs <- recluster(graph, resolution, nseeds = nseeds)
    agreement <- simple_sim(runs)
    agreement[agreement < tau] <- 0
    list(runs = runs, agreement = agreement)
  }
  is_settled <- function(mat) {
    values <- as.vector(mat)
    all((values == 0) | (values == 1))
  }
  current <- run_round(G)
  rounds <- 1
  while (!is_settled(current$agreement)) {
    if (rounds > max_iter) {
      warning("failed to converge")
      break
    }
    rounds <- rounds + 1
    consensus_graph <- graph_from_adjacency_matrix(
      current$agreement,
      mode = "undirected",
      weighted = TRUE,
      diag = FALSE
    )
    current <- run_round(consensus_graph)
  }
  list(clusters = current$runs[, 1], adjacency = current$agreement)
}
# Build a shared-nearest-neighbour graph of the cells in `sce` from one of
# its reduced-dimension embeddings.
#
# sce, dimred: passed to reducedDim(). sce is presumably a
#   SingleCellExperiment; TODO confirm.
# k, type: forwarded to bluster::makeSNNGraph().
# seed: if non-NULL, passed to set.seed() first.
# The neighbour search uses BiocNeighbors::KmknnParam() and runs serially
# (BiocParallel::SerialParam()).
get_graph <- function(sce, dimred, k = 10, type = "rank", seed = 123) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  embedding <- as.matrix(reducedDim(sce, dimred))
  bluster::makeSNNGraph(
    embedding,
    k = k,
    type = type,
    BNPARAM = BiocNeighbors::KmknnParam(),
    BPPARAM = BiocParallel::SerialParam()
  )
}
# Cluster graph G with Leiden at several resolutions.
#
# G: graph passed to recluster() / consensus_clust().
# resolutions: Leiden resolution parameters, processed in the given order.
# n: runs per resolution.
# consensus: if TRUE, each resolution contributes one consensus clustering
#   (consensus_clust() with max_iter = 5, tau = 0.2). Otherwise it
#   contributes all n runs.
# objective_function, weights, beta, n_iterations: forwarded to
#   recluster() (hence cluster_leiden()) when consensus = FALSE.
#   consensus_clust() does not take them, so the consensus path uses
#   recluster()'s defaults.
# seed: passed to set.seed() before any clustering.
# Returns a data.frame with one factor column per run, with column names
# built from the resolution and the run index.
multi_resolution_clust <- function(
  G,
  resolutions,
  n = 10,
  consensus = FALSE,
  objective_function = "CPM",
  weights = NULL,
  beta = 0.01,
  n_iterations = 2,
  seed = 123
) {
  set.seed(seed)
  out <- list()
  for (resolution in resolutions) {
    if (consensus) {
      cl <- consensus_clust(G, resolution, nseeds = n, max_iter = 5, tau = 0.2)
      M <- data.frame(x = as.factor(cl$clusters))
      colnames(M) <- sprintf("R%f_C1", resolution)
    } else {
      # Forward the caller's Leiden settings; they were previously accepted
      # but never used.
      M <- recluster(
        G,
        resolution = resolution,
        objective_function = objective_function,
        weights = weights,
        beta = beta,
        n_iterations = n_iterations,
        nseeds = n
      )
      M <- as.data.frame(lapply(M, factor))
    }
    out[[sprintf("R%f", resolution)]] <- M
  }
  do.call(cbind, unname(out))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment