R code for visualization of similarity matrices, by reordering them so high values move to the diagonal (become "slanted"), and optionally applying order-preserving clustering to the result.
#' Compute rows and columns orders which move high values close to the diagonal.
#'
#' For a matrix expressing the cross-similarity between two (possibly different)
#' sets of entities, this produces better results than clustering (e.g. as done
#' by \code{pheatmap}). This is because clustering does not care about the relative
#' order of the two sub-trees under each node. That is, clustering is as happy with
#' \code{((2, 1), (4, 3))} as it is with the more sensible \code{((1, 2), (3, 4))}.
#' As a result, visualizations of similarities using naive clustering can be misleading.
#'
#' @param data A rectangular matrix.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows.
#' @param order_cols Whether to reorder the columns.
#' @param same_order Whether to apply the same order to both rows and columns.
#' @param max_spin_count How many times to retry improving the solution before giving up.
#' @return A list with two keys, \code{rows} and \code{cols}, which contain the order.
#'
#' @export
slanted_orders <- function(data, ..., order=T, order_rows=NULL, order_cols=NULL,
                           same_order=F, max_spin_count=10) {
    if (is.null(order_rows)) { order_rows = order }
    if (is.null(order_cols)) { order_cols = order }
    wrapr::stop_if_dot_args(substitute(list(...)), 'slanted_orders')
    rows_count <- dim(data)[1]
    cols_count <- dim(data)[2]
    rows_indices <- as.vector(1:rows_count)
    cols_indices <- as.vector(1:cols_count)
    rows_permutation <- rows_indices
    cols_permutation <- cols_indices
    if (same_order) {
        stopifnot(rows_count == cols_count)
        permutation <- rows_indices
    }
    if (order_rows || order_cols) {
        squared_data <- data * data
        epsilon <- min(squared_data[squared_data > 0]) / 10
        reorder_phase <- function() {
            spinning_rows_count <- 0
            spinning_cols_count <- 0
            spinning_same_count <- 0
            was_changed <- T
            error_rows <- NULL
            error_cols <- NULL
            error_same <- NULL
            while (was_changed) {
                was_changed <- F
                ideal_index <- NULL
                if (order_rows) {
                    sum_indexed_rows <- rowSums(sweep(squared_data, 2, cols_indices, `*`))
                    sum_squared_rows <- rowSums(squared_data)
                    ideal_row_index <- (sum_indexed_rows + epsilon) / (sum_squared_rows + epsilon)
                    if (same_order) {
                        ideal_index <- ideal_row_index
                    } else {
                        error <- ideal_row_index - rows_indices
                        new_error_rows <- sum(error * error)
                        new_rows_permutation <- order(ideal_row_index)
                        new_changed <- any(new_rows_permutation != rows_indices)
                        if (is.null(error_rows) || new_error_rows < error_rows) {
                            error_rows <- new_error_rows
                            spinning_rows_count <- 0
                        } else {
                            spinning_rows_count <- spinning_rows_count + 1
                        }
                        if (new_changed && spinning_rows_count < max_spin_count) {
                            was_changed <- T
                            squared_data <<- squared_data[new_rows_permutation,]
                            rows_permutation <<- rows_permutation[new_rows_permutation]
                        }
                    }
                }
                if (order_cols) {
                    sum_indexed_cols <- colSums(sweep(squared_data, 1, rows_indices, `*`))
                    sum_squared_cols <- colSums(squared_data)
                    ideal_col_index <- (sum_indexed_cols + epsilon) / (sum_squared_cols + epsilon)
                    if (same_order) {
                        if (!is.null(ideal_index)) {
                            ideal_index <- (ideal_index + ideal_col_index) / 2
                        } else {
                            ideal_index <- ideal_col_index
                        }
                    } else {
                        error <- ideal_col_index - cols_indices
                        new_error_cols <- sum(error * error)
                        new_cols_permutation <- order(ideal_col_index)
                        new_changed <- any(new_cols_permutation != cols_indices)
                        if (is.null(error_cols) || new_error_cols < error_cols) {
                            error_cols <- new_error_cols
                            spinning_cols_count <- 0
                        } else {
                            spinning_cols_count <- spinning_cols_count + 1
                        }
                        if (new_changed && spinning_cols_count < max_spin_count) {
                            was_changed <- T
                            squared_data <<- squared_data[,new_cols_permutation]
                            cols_permutation <<- cols_permutation[new_cols_permutation]
                        }
                    }
                }
                if (!is.null(ideal_index)) {
                    error <- ideal_index - rows_indices
                    new_error_same <- sum(error * error)
                    new_permutation <- order(ideal_index)
                    new_changed <- any(new_permutation != rows_indices)
                    if (is.null(error_same) || new_error_same < error_same) {
                        error_same <- new_error_same
                        spinning_same_count <- 0
                    } else {
                        spinning_same_count <- spinning_same_count + 1
                    }
                    if (new_changed && spinning_same_count < max_spin_count) {
                        was_changed <- T
                        squared_data <<- squared_data[new_permutation,new_permutation]
                        permutation <<- permutation[new_permutation]
                        rows_permutation <<- permutation
                        cols_permutation <<- permutation
                    }
                }
            }
        }
        discount_outliers <- function() {
            row_indices_matrix <- matrix(rep(rows_indices, each=cols_count),
                                         nrow=rows_count, ncol=cols_count, byrow=T)
            col_indices_matrix <- matrix(rep(cols_indices, each=rows_count),
                                         nrow=rows_count, ncol=cols_count, byrow=F)
            rows_per_col <- rows_count / cols_count
            cols_per_row <- cols_count / rows_count
            ideal_row_indices_matrix <- col_indices_matrix * rows_per_col
            ideal_col_indices_matrix <- row_indices_matrix * cols_per_row
            row_distance_matrix <- row_indices_matrix - ideal_row_indices_matrix
            col_distance_matrix <- col_indices_matrix - ideal_col_indices_matrix
            weight_matrix <- (1 + abs(row_distance_matrix)) * (1 + abs(col_distance_matrix))
            squared_data <<- squared_data / weight_matrix
        }
        reorder_phase()
        discount_outliers()
        reorder_phase()
    }
    return (list(rows=rows_permutation, cols=cols_permutation))
}
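
# A minimal usage sketch for `slanted_orders` (the data and the `demo_slanted_orders`
# wrapper are hypothetical, for illustration only): build a noisy banded similarity
# matrix, scramble it, and recover orders that pull the high values back to the diagonal.
demo_slanted_orders <- function() {
    set.seed(1)
    similarity <- outer(1:20, 1:30, function(row, col) exp(-3 * abs(row / 20 - col / 30)))
    similarity <- similarity + matrix(runif(20 * 30, 0, 0.1), nrow=20, ncol=30)
    similarity <- similarity[sample(1:20), sample(1:30)]  # scramble rows and columns
    orders <- slanted_orders(similarity)
    similarity[orders$rows, orders$cols]  # high values should hug the diagonal again
}
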
#' Reorder data rows and columns to move high values close to the diagonal.
#'
#' Given a matrix expressing the cross-similarity between two (possibly different)
#' sets of entities, this uses \code{slanted_orders} to compute the "best" order
#' for visualizing the matrix, then returns the reordered data. Commonly used in:
#' \code{pheatmap(slanted_reorder(data), cluster_rows=F, cluster_cols=F)}.
#'
#' @param data A rectangular matrix.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows.
#' @param order_cols Whether to reorder the columns.
#' @param same_order Whether to apply the same order to both rows and columns.
#' @return A matrix of the same shape whose rows and columns are a permutation of the input.
#'
#' @export
slanted_reorder <- function(data, ..., order=T, order_rows=NULL, order_cols=NULL, same_order=F) {
    wrapr::stop_if_dot_args(substitute(list(...)), 'slanted_reorder')
    orders <- slanted_orders(data,
                             order=order, order_rows=order_rows, order_cols=order_cols,
                             same_order=same_order)
    return (data[orders$rows, orders$cols])
}
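
# Sketch: for a square self-similarity matrix (e.g. correlations between the same
# set of profiles), `same_order=T` keeps the reordered matrix symmetric before
# handing it to `pheatmap`; `self_similarity` is a placeholder name.
demo_slanted_reorder <- function(self_similarity) {
    reordered <- slanted_reorder(self_similarity, same_order=T)
    pheatmap::pheatmap(reordered, cluster_rows=F, cluster_cols=F)
}
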
#' Cluster ordered data.
#'
#' Given a distance matrix for sorted objects, compute a hierarchical clustering preserving this
#' order. That is, this is similar to `hclust` with the constraint that the result's order is
#' always `1:N`. This can be applied to the results of `slanted_reorder`, to give a "plausible"
#' clustering for the data.
#'
#' @param dist A distances object (as created by `dist`).
#' @param method The method of computing the clusters. Valid values are `agglomerative`
#'        (bottom-up, the default) or `divisive` (top-down).
#' @param aggregation How to aggregate distances between clusters; valid values are `mean`
#'        (the default), `min` and `max`.
#' @return A clustering object (as created by `hclust`).
#'
#' @export
oclust <- function(dist, ...,
                   method=c('agglomerative', 'divisive'), aggregation=c('mean', 'min', 'max')) {
    wrapr::stop_if_dot_args(substitute(list(...)), 'oclust')
    aggregation <- switch(match.arg(aggregation), mean='mean', min='min', max='max')
    method <- switch(match.arg(method), agglomerative='agglomerative', divisive='divisive')
    distances <- as.matrix(dist)
    stopifnot(dim(distances)[1] == dim(distances)[2])
    rows_count <- dim(distances)[1]
    if (method == 'agglomerative') {
        hclust <- bottom_up(distances, aggregation)
    } else {
        hclust <- top_down(distances, aggregation)
    }
    hclust$labels <- rownames(distances)
    hclust$method <- sprintf('oclust.%s.%s', method, aggregation)
    if (!is.null(attr(dist, 'method'))) {
        hclust$dist.method <- attr(dist, 'method')
    }
    if (!is.null(attr(dist, 'Labels'))) {
        hclust$labels <- attr(dist, 'Labels')
    }
    hclust$order <- 1:rows_count
    class(hclust) <- 'hclust'
    return (hclust)
}
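
# Sketch of combining the pieces by hand (what `sheatmap` below automates): reorder
# the data, compute distances between the now-ordered rows, and build an
# order-preserving dendrogram. `similarity` is a placeholder for the caller's matrix.
demo_oclust <- function(similarity) {
    ordered <- slanted_reorder(similarity)
    rows_distances <- enhanced_dist(ordered, method='pearson')
    rows_clusters <- oclust(rows_distances, method='agglomerative', aggregation='mean')
    stopifnot(all(rows_clusters$order == 1:nrow(ordered)))  # the row order is preserved
    rows_clusters
}
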
# TODO: This can be made to run much faster.
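# Both helpers below return the `merge` and `height` components expected by an
# `hclust` object: row i of `merge` describes the i-th merge, where a negative
# entry -j refers to original observation j and a positive entry k refers to the
# cluster formed at merge k; `height` records the distance assigned to each merge
# (accumulated here so that parents are always taller than their children).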
bottom_up <- function(distances, aggregation) {
    aggregate <- switch(aggregation, mean=mean, min=min, max=max)
    rows_count <- dim(distances)[1]
    diag(distances) <- Inf
    merge <- matrix(0, nrow=rows_count - 1, ncol=2)
    height <- rep(0, rows_count - 1)
    merged_height <- rep(0, rows_count)
    groups <- -(1:rows_count)
    for (merge_index in 1:(rows_count - 1)) {
        adjacent_distances <- pracma::Diag(distances, 1)
        low_index <- which.min(adjacent_distances)
        high_index <- low_index + 1
        grouped_indices <- sort(groups[c(low_index, high_index)])
        merged_indices <- which(groups %in% grouped_indices)
        groups[merged_indices] <- merge_index
        merge[merge_index,] <- grouped_indices
        height[merge_index] <- max(merged_height[merged_indices]) + adjacent_distances[low_index]
        merged_height[merged_indices] <- height[merge_index]
        merged_distances <- apply(distances[,merged_indices], 1, aggregate)
        distances[,merged_indices] <- merged_distances
        distances[merged_indices,] <- rep(merged_distances, each=length(merged_indices))
        distances[merged_indices, merged_indices] <- Inf
    }
    return (list(merge=merge, height=height))
}
top_down <- function(distances, aggregation) {
    aggregate <- switch(aggregation, mean=cumsum, min=cummin, max=cummax)
    rows_count <- dim(distances)[1]
    merge <- matrix(0, nrow=rows_count - 1, ncol=2)
    height <- rep(0, rows_count - 1)
    accumulator <- list(merge=merge, height=height, merge_index=rows_count-1)
    accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation, 1:rows_count)
    return (list(merge=accumulator$merge, height=accumulator$height))
}

top_down_divide <- function(accumulator, distances, aggregate, aggregation, indices_range) {
    rows_count <- dim(distances)[1]
    split_count <- length(indices_range)
    stopifnot(split_count > 1)
    effective_distances <- distances[indices_range, rev(indices_range)]
    effective_distances <- apply(effective_distances, 2, aggregate)
    effective_distances <- t(apply(t(effective_distances), 2, aggregate)) # TODO
    effective_distances <- effective_distances[1:split_count, split_count:1]
    candidate_distances <- pracma::Diag(effective_distances, 1)
    if (aggregation == 'mean') {
        candidate_distances <- candidate_distances / 1:length(candidate_distances)
        candidate_distances <- candidate_distances / length(candidate_distances):1
    }
    split_position <- which.max(candidate_distances)
    split_index <- split_position + min(indices_range) - 1
    low_range <- min(indices_range):split_index
    high_range <- (split_index + 1):max(indices_range)
    stopifnot(length(low_range) < split_count)
    stopifnot(length(high_range) < split_count)
    merge_index <- accumulator$merge_index
    if (length(low_range) == 1) {
        low_index <- -min(low_range)
        low_height <- 0
    } else {
        low_index <- accumulator$merge_index - 1
        accumulator$merge_index <- low_index
        accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation, low_range)
        low_height <- accumulator$height[low_index]
    }
    if (length(high_range) == 1) {
        high_index <- -min(high_range)
        high_height <- 0
    } else {
        high_index <- accumulator$merge_index - 1
        accumulator$merge_index <- high_index
        accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation, high_range)
        high_height <- accumulator$height[high_index]
    }
    accumulator$height[merge_index] <- candidate_distances[split_position] + max(low_height, high_height)
    accumulator$merge[merge_index,] <- sort(c(low_index, high_index))
    return (accumulator)
}
#' Enhanced version of `dist`.
#'
#' This also allows using the correlations as the basis for the distances.
#' If the method is `pearson`, `kendall` or `spearman`, then the distances
#' will be `2 - cor(t(data), method)`. Otherwise, `dist` will be used.
#'
#' @param data A rectangular matrix whose rows are the entities to compute distances between.
#' @param method One of the correlation methods above, or any method accepted by `dist`.
#' @return The distances: a square matrix of `2 - correlation` values, or a `dist` object.
enhanced_dist <- function(data, method) {
    if (method == 'pearson' || method == 'kendall' || method == 'spearman') {
        if (method == 'pearson' && exists('tgs_cor')) {
            distances <- 2 - tgs_cor(t(data))  # prefer the faster tgs_cor implementation when available
        } else {
            distances <- 2 - cor(t(data), method=method)
        }
        distances_attributes <- attributes(distances)
        distances_attributes$method <- method
        attributes(distances) <- distances_attributes
        return (distances)
    }
    return (dist(data, method))
}
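
# Sketch: with a correlation method the result is a square matrix of `2 - correlation`
# values (1 for perfectly correlated rows, 3 for perfectly anti-correlated ones);
# with any other method it falls back to `dist`. `data` is a placeholder argument.
demo_enhanced_dist <- function(data) {
    pearson_distances <- enhanced_dist(data, method='pearson')      # square matrix, 2 - cor
    euclidean_distances <- enhanced_dist(data, method='euclidean')  # a `dist` object
    list(pearson=pearson_distances, euclidean=euclidean_distances)
}
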
#' Plot a heatmap with values as close to the diagonal as possible.
#'
#' Given a matrix expressing the cross-similarity between two (possibly different) sets of
#' entities, this uses \code{slanted_reorder} to move the high values close to the diagonal, then
#' computes an order-preserving clustering for visualizing the matrix with a dendrogram tree, and
#' passes all this to `pheatmap`.
#'
#' @param data A rectangular matrix.
#' @param annotation_row Optional data frame describing each row.
#' @param annotation_col Optional data frame describing each column.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows.
#' @param order_cols Whether to reorder the columns.
#' @param same_order Whether to apply the same order to both rows and columns.
#' @param cluster The default for whether to cluster rows and columns.
#' @param cluster_rows Whether to cluster the rows (specify `order_rows=F` if giving an `hclust`).
#' @param cluster_cols Whether to cluster the columns (specify `order_cols=F` if giving an `hclust`).
#' @param distance_function The function for computing distance matrices (by default, `enhanced_dist`).
#' @param clustering_distance The default method for computing distances (by default, `pearson`).
#' @param clustering_distance_rows The method for computing distances between rows.
#' @param clustering_distance_cols The method for computing distances between columns.
#' @param clustering_method The default method to use for clustering the ordered data (by default, `agglomerative`).
#' @param clustering_method_rows The method to use for clustering the ordered rows.
#' @param clustering_method_cols The method to use for clustering the ordered columns.
#' @param clustering_aggregation The default aggregation method of cluster distances (by default, `mean`).
#' @param clustering_aggregation_rows How to aggregate distances of clusters of rows.
#' @param clustering_aggregation_cols How to aggregate distances of clusters of columns.
#' @param ... Additional flags to pass to `pheatmap`.
#' @return Whatever `pheatmap` returns.
sheatmap <- function(data, ...,
                     annotation_col=NULL,
                     annotation_row=NULL,
                     order=T,
                     order_rows=NULL,
                     order_cols=NULL,
                     same_order=F,
                     cluster=F,
                     cluster_rows=NULL,
                     cluster_cols=NULL,
                     distance_function=enhanced_dist,
                     clustering_distance='pearson',
                     clustering_distance_rows=NULL,
                     clustering_distance_cols=NULL,
                     clustering_method='agglomerative',
                     clustering_method_rows=NULL,
                     clustering_method_cols=NULL,
                     clustering_aggregation='mean',
                     clustering_aggregation_rows=NULL,
                     clustering_aggregation_cols=NULL) {
    if (is.null(cluster_rows)) { cluster_rows = cluster }
    if (is.null(cluster_cols)) { cluster_cols = cluster }
    if (is.null(clustering_distance_rows)) { clustering_distance_rows = clustering_distance }
    if (is.null(clustering_method_rows)) { clustering_method_rows = clustering_method }
    if (is.null(clustering_aggregation_rows)) { clustering_aggregation_rows = clustering_aggregation }
    if (is.null(clustering_distance_cols)) { clustering_distance_cols = clustering_distance }
    if (is.null(clustering_method_cols)) { clustering_method_cols = clustering_method }
    if (is.null(clustering_aggregation_cols)) { clustering_aggregation_cols = clustering_aggregation }
    orders <- slanted_orders(data, order=order, order_rows=order_rows, order_cols=order_cols,
                             same_order=same_order)
    data <- data[orders$rows, orders$cols]
    if (!is.null(annotation_row)) {
        annotation_row <- reorder_frame(annotation_row, orders$rows)
    }
    if (!is.null(annotation_col)) {
        annotation_col <- reorder_frame(annotation_col, orders$cols)
    }
    if (isTRUE(cluster_rows)) {
        rows_distances <- distance_function(data, method=clustering_distance_rows)
        cluster_rows <- oclust(rows_distances,
                               method=clustering_method_rows,
                               aggregation=clustering_aggregation_rows)
    }
    if (isTRUE(cluster_cols)) {
        cols_distances <- distance_function(t(data), method=clustering_distance_cols)
        cluster_cols <- oclust(cols_distances,
                               method=clustering_method_cols,
                               aggregation=clustering_aggregation_cols)
    }
    return (pheatmap::pheatmap(data, annotation_row=annotation_row, annotation_col=annotation_col,
                               cluster_rows=cluster_rows, cluster_cols=cluster_cols, ...))
}
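
# Sketch of a typical `sheatmap` call on a square self-similarity matrix, reordering
# rows and columns together and adding order-preserving dendrograms on both axes;
# `self_similarity` is a placeholder name.
demo_sheatmap <- function(self_similarity) {
    sheatmap(self_similarity, same_order=T, cluster=T)
}
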
#' Reorder the rows of a frame.
#'
#' If you expect \code{data[order]} to just work, you haven't been using R for very long.
#' It is this sort of thing that makes me *hate* coding in R.
#'
#' @param data A data frame to reorder the rows of.
#' @param order An array containing the indices permutation to apply to the rows.
#' @return The data frame with the new row order.
reorder_frame <- function(data, order) {
    row_names <- rownames(data)
    if (ncol(data) == 1) {
        vec <- t(data[1])
        data[1] <- vec[order]
    } else {
        data <- data[order,]
    }
    rownames(data) <- row_names[order]
    return (data)
}
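
# Sketch of why the single-column case needs special handling: for a one-column data
# frame, plain `frame[order, ]` drops the result to a vector (unless `drop=FALSE` is
# given), losing the row names the heatmap annotations rely on. Hypothetical example:
demo_reorder_frame <- function() {
    annotations <- data.frame(type=c('a', 'b', 'c'), row.names=c('x', 'y', 'z'))
    reorder_frame(annotations, c(3, 1, 2))  # still a data frame, rows now z, x, y
}
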