Skip to content

Instantly share code, notes, and snippets.

@orenbenkiki
Last active June 4, 2020 09:37
Show Gist options
  • Save orenbenkiki/b131a4604442666616a309dfea1b3044 to your computer and use it in GitHub Desktop.
Save orenbenkiki/b131a4604442666616a309dfea1b3044 to your computer and use it in GitHub Desktop.
R code for visualizing similarity matrices by reordering them so that high values move to the diagonal (become "slanted"), and optionally applying order-preserving clustering to the result.
#' Compute rows and columns orders which move high values close to the diagonal.
#'
#' For a matrix expressing the cross-similarity between two (possibly different)
#' sets of entities, this produces better results than clustering (e.g. as done
#' by \code{pheatmap}). This is because clustering does not care about the order
#' of each two sub-partitions. That is, clustering is as happy with \code{((2, 1),
#' (4, 3))} as it is with the more sensible \code{((1, 2), (3, 4))}. As a result,
#' visualizations of similarities using naive clustering can be misleading.
#'
#' The computation works on the squared values. Each row (column) gets an "ideal"
#' position: the mean of the column (row) indices weighted by its squared values,
#' smoothed by a small epsilon. Rows (columns) are then sorted by that position.
#' This repeats until the order stops changing, or until the squared error
#' between ideal and actual positions fails to improve for \code{max_spin_count}
#' consecutive iterations. The values are then divided by a weight that grows
#' with their distance from the diagonal, and the whole process is repeated once.
#'
#' @param data A rectangular matrix.
#' @param ... Must be empty; all other parameters must be passed by name.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows (defaults to \code{order}).
#' @param order_cols Whether to reorder the columns (defaults to \code{order}).
#' @param same_order Whether to apply the same order to both rows and columns
#'   (requires a square matrix).
#' @param max_spin_count How many times to retry improving the solution before giving up.
#' @return A list with two keys, \code{rows} and \code{cols}, which contain the order.
#'   If \code{data} has no positive squared entries (e.g. it is all zeros), there is
#'   nothing to order by and the identity orders are returned.
#'
#' @export
slanted_orders <- function(data, ..., order=TRUE, order_rows=NULL, order_cols=NULL,
                           same_order=FALSE, max_spin_count=10) {
  if (is.null(order_rows)) { order_rows <- order }
  if (is.null(order_cols)) { order_cols <- order }
  wrapr::stop_if_dot_args(substitute(list(...)), 'slanted_orders')
  rows_count <- dim(data)[1]
  cols_count <- dim(data)[2]
  rows_indices <- seq_len(rows_count)
  cols_indices <- seq_len(cols_count)
  rows_permutation <- rows_indices
  cols_permutation <- cols_indices
  if (same_order) {
    stopifnot(rows_count == cols_count)
    permutation <- rows_indices
  }
  identity_orders <- list(rows=rows_permutation, cols=cols_permutation)
  if (!(order_rows || order_cols)) {
    return(identity_orders)
  }

  squared_data <- data * data
  positive_squares <- squared_data[squared_data > 0]
  if (length(positive_squares) == 0) {
    # Without any positive value there is nothing to order by. Without this guard,
    # min() would warn and return Inf, turning every ideal index into NaN.
    return(identity_orders)
  }
  # Smoothing term that keeps all-zero rows/columns from dividing 0 by 0.
  epsilon <- min(positive_squares) / 10

  # Iteratively sort rows/columns by their weighted "ideal" index. The `<<-`
  # assignments update squared_data and the permutations of slanted_orders, so
  # each step sees the data as already reordered by the previous steps.
  reorder_phase <- function() {
    spinning_rows_count <- 0
    spinning_cols_count <- 0
    spinning_same_count <- 0
    was_changed <- TRUE
    error_rows <- NULL
    error_cols <- NULL
    error_same <- NULL
    while (was_changed) {
      was_changed <- FALSE
      ideal_index <- NULL
      if (order_rows) {
        sum_indexed_rows <- rowSums(sweep(squared_data, 2, cols_indices, `*`))
        sum_squared_rows <- rowSums(squared_data)
        ideal_row_index <- (sum_indexed_rows + epsilon) / (sum_squared_rows + epsilon)
        if (same_order) {
          ideal_index <- ideal_row_index
        } else {
          error <- ideal_row_index - rows_indices
          new_error_rows <- sum(error * error)
          # `base::` because the `order` parameter shadows the function name here.
          new_rows_permutation <- base::order(ideal_row_index)
          new_changed <- any(new_rows_permutation != rows_indices)
          if (is.null(error_rows) || new_error_rows < error_rows) {
            error_rows <- new_error_rows
            spinning_rows_count <- 0
          } else {
            spinning_rows_count <- spinning_rows_count + 1
          }
          if (new_changed && spinning_rows_count < max_spin_count) {
            was_changed <- TRUE
            squared_data <<- squared_data[new_rows_permutation, , drop=FALSE]
            rows_permutation <<- rows_permutation[new_rows_permutation]
          }
        }
      }
      if (order_cols) {
        sum_indexed_cols <- colSums(sweep(squared_data, 1, rows_indices, `*`))
        sum_squared_cols <- colSums(squared_data)
        ideal_col_index <- (sum_indexed_cols + epsilon) / (sum_squared_cols + epsilon)
        if (same_order) {
          if (!is.null(ideal_index)) {
            ideal_index <- (ideal_index + ideal_col_index) / 2
          } else {
            ideal_index <- ideal_col_index
          }
        } else {
          error <- ideal_col_index - cols_indices
          new_error_cols <- sum(error * error)
          new_cols_permutation <- base::order(ideal_col_index)
          new_changed <- any(new_cols_permutation != cols_indices)
          if (is.null(error_cols) || new_error_cols < error_cols) {
            error_cols <- new_error_cols
            spinning_cols_count <- 0
          } else {
            spinning_cols_count <- spinning_cols_count + 1
          }
          if (new_changed && spinning_cols_count < max_spin_count) {
            was_changed <- TRUE
            squared_data <<- squared_data[, new_cols_permutation, drop=FALSE]
            cols_permutation <<- cols_permutation[new_cols_permutation]
          }
        }
      }
      # Only reached with same_order: one shared permutation for rows and columns.
      if (!is.null(ideal_index)) {
        error <- ideal_index - rows_indices
        new_error_same <- sum(error * error)
        new_permutation <- base::order(ideal_index)
        new_changed <- any(new_permutation != rows_indices)
        if (is.null(error_same) || new_error_same < error_same) {
          error_same <- new_error_same
          spinning_same_count <- 0
        } else {
          spinning_same_count <- spinning_same_count + 1
        }
        if (new_changed && spinning_same_count < max_spin_count) {
          was_changed <- TRUE
          squared_data <<- squared_data[new_permutation, new_permutation, drop=FALSE]
          permutation <<- permutation[new_permutation]
          rows_permutation <<- permutation
          cols_permutation <<- permutation
        }
      }
    }
  }

  # Divide each value by (1 + row distance) * (1 + column distance) from the
  # (rescaled, for non-square matrices) diagonal, so far-off values matter less.
  discount_outliers <- function() {
    row_indices_matrix <- matrix(rep(rows_indices, each=cols_count),
                                 nrow=rows_count, ncol=cols_count, byrow=TRUE)
    col_indices_matrix <- matrix(rep(cols_indices, each=rows_count),
                                 nrow=rows_count, ncol=cols_count, byrow=FALSE)
    rows_per_col <- rows_count / cols_count
    cols_per_row <- cols_count / rows_count
    ideal_row_indices_matrix <- col_indices_matrix * rows_per_col
    ideal_col_indices_matrix <- row_indices_matrix * cols_per_row
    row_distance_matrix <- row_indices_matrix - ideal_row_indices_matrix
    col_distance_matrix <- col_indices_matrix - ideal_col_indices_matrix
    weight_matrix <- (1 + abs(row_distance_matrix)) * (1 + abs(col_distance_matrix))
    squared_data <<- squared_data / weight_matrix
  }

  reorder_phase()
  discount_outliers()
  reorder_phase()
  list(rows=rows_permutation, cols=cols_permutation)
}
#' Reorder data rows and columns to move high values close to the diagonal.
#'
#' Given a matrix expressing the cross-similarity between two (possibly different)
#' sets of entities, this uses \code{slanted_orders} to compute the "best" order
#' for visualizing the matrix, then returns the reordered data. Commonly used in:
#' \code{pheatmap(slanted_reorder(data),cluster_rows=F,cluster_cols=F)}.
#'
#' @param data A rectangular matrix.
#' @param ... Must be empty; all other parameters must be passed by name.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows.
#' @param order_cols Whether to reorder the columns.
#' @param same_order Whether to apply the same order to both rows and columns.
#' @param max_spin_count How many times to retry improving the solution before giving up
#'   (passed to \code{slanted_orders}).
#' @return A matrix of the same shape whose rows and columns are a permutation of the input.
#'   Dimensions are kept even when the matrix has a single row or column.
#'
#' @export
slanted_reorder <- function(data, ..., order=TRUE, order_rows=NULL, order_cols=NULL,
                            same_order=FALSE, max_spin_count=10) {
  wrapr::stop_if_dot_args(substitute(list(...)), 'slanted_reorder')
  orders <- slanted_orders(data,
                           order=order, order_rows=order_rows, order_cols=order_cols,
                           same_order=same_order, max_spin_count=max_spin_count)
  # drop=FALSE: a 1-row or 1-column result must stay a matrix.
  data[orders$rows, orders$cols, drop=FALSE]
}
#' Cluster ordered data.
#'
#' Given a distance matrix for sorted objects, compute a hierarchical clustering preserving this
#' order. That is, this is similar to `hclust` with the constraint that the result's order is
#' always `1:N`. This can be applied to the results of `slanted_reorder`, to give a "plausible"
#' clustering for the data.
#'
#' @param dist A distances object (as created by `dist`), or a square distance matrix.
#' @param ... Must be empty; all other parameters must be passed by name.
#' @param method The method of computing the clusters. Valid values are `agglomerative`
#' (bottom-up, the default) or `divisive` (top-down).
#' @param aggregation How to aggregate distances between clusters; valid values are `mean`
#' (the default), `min` and `max`.
#' @return A clustering object (as created by `hclust`). Its labels are the `Labels`
#' attribute of `dist` if present, and otherwise the row names of `as.matrix(dist)`.
#'
#' @export
oclust <- function(dist, ...,
                   method=c('agglomerative', 'divisive'), aggregation=c('mean', 'min', 'max')) {
  wrapr::stop_if_dot_args(substitute(list(...)), 'oclust')
  aggregation <- match.arg(aggregation)
  method <- match.arg(method)
  distances <- as.matrix(dist)
  stopifnot(dim(distances)[1] == dim(distances)[2])
  rows_count <- dim(distances)[1]
  if (method == 'agglomerative') {
    tree <- bottom_up(distances, aggregation)
  } else {
    tree <- top_down(distances, aggregation)
  }
  # Default labels come from the distance matrix itself; an explicit `Labels`
  # attribute (as set by `dist`) takes precedence below.
  tree$labels <- rownames(distances)
  tree$method <- sprintf('oclust.%s.%s', method, aggregation)
  if (!is.null(attr(dist, 'method'))) {
    tree$dist.method <- attr(dist, 'method')
  }
  if (!is.null(attr(dist, 'Labels'))) {
    tree$labels <- attr(dist, 'Labels')
  }
  # The clustering preserves the input order by construction.
  tree$order <- seq_len(rows_count)
  class(tree) <- 'hclust'
  tree
}
# TODO: This can be made to run much faster.
#
# Agglomerative (bottom-up) clustering that preserves the given order: at each
# step, merge the two *adjacent* groups with the smallest aggregated distance.
#
# distances: a square distance matrix (its diagonal is ignored).
# aggregation: one of 'mean', 'min', 'max'; how the distance between a merged
#   group and every other object is computed from its members' distances.
# Returns a list with `merge` and `height` in `hclust` format: negative entries
# in `merge` are singletons, positive entries refer to earlier merge rows. Each
# height is the merge distance plus the largest height already reached by the
# merged members. A single object yields zero merges.
bottom_up <- function(distances, aggregation) {
  aggregate <- switch(aggregation, mean=mean, min=min, max=max)
  rows_count <- dim(distances)[1]
  merges_count <- max(rows_count - 1, 0)
  diag(distances) <- Inf
  merge <- matrix(0, nrow=merges_count, ncol=2)
  height <- rep(0, merges_count)
  merged_height <- rep(0, rows_count)
  groups <- -seq_len(rows_count)
  # (i, i + 1) positions: the first super-diagonal, i.e. each object and its right neighbour.
  adjacent_positions <- cbind(seq_len(merges_count), seq_len(merges_count) + 1)
  for (merge_index in seq_len(merges_count)) {
    adjacent_distances <- distances[adjacent_positions]
    low_index <- which.min(adjacent_distances)
    high_index <- low_index + 1
    grouped_indices <- sort(groups[c(low_index, high_index)])
    merged_indices <- which(groups %in% grouped_indices)
    groups[merged_indices] <- merge_index
    merge[merge_index, ] <- grouped_indices
    height[merge_index] <- max(merged_height[merged_indices]) + adjacent_distances[low_index]
    merged_height[merged_indices] <- height[merge_index]
    # Every member of the new group gets the same aggregated distance to each object.
    merged_distances <- apply(distances[, merged_indices, drop=FALSE], 1, aggregate)
    distances[, merged_indices] <- merged_distances
    distances[merged_indices, ] <- rep(merged_distances, each=length(merged_indices))
    # Distances inside a group are Inf so its members are never "adjacent" again.
    distances[merged_indices, merged_indices] <- Inf
  }
  list(merge=merge, height=height)
}
# Divisive (top-down) clustering that preserves the given order: recursively
# split each range of objects at the point of maximal aggregated distance
# between its prefix and suffix (see `top_down_divide`).
#
# distances: a square distance matrix.
# aggregation: one of 'mean', 'min', 'max'.
# Returns a list with `merge` and `height` in `hclust` format. A single object
# yields zero merges, consistent with `bottom_up`.
top_down <- function(distances, aggregation) {
  # Cumulative forms of the aggregation, so top_down_divide can aggregate all
  # prefix/suffix blocks of a range in one pass ('mean' divides the sums later).
  aggregate <- switch(aggregation, mean=cumsum, min=cummin, max=cummax)
  rows_count <- dim(distances)[1]
  if (rows_count < 2) {
    return(list(merge=matrix(0, nrow=0, ncol=2), height=numeric(0)))
  }
  merge <- matrix(0, nrow=rows_count - 1, ncol=2)
  height <- rep(0, rows_count - 1)
  # The root split is the last merge row; children get decreasing row indices.
  accumulator <- list(merge=merge, height=height, merge_index=rows_count - 1)
  accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation,
                                 seq_len(rows_count))
  list(merge=accumulator$merge, height=accumulator$height)
}
# Recursively split `indices_range` (a contiguous range of object indices)
# in two, filling `accumulator$merge` / `accumulator$height` in `hclust` format.
#
# accumulator: list(merge, height, merge_index). The node for this range is
#   written at row `merge_index`; sub-ranges get successively smaller rows, so
#   every child row precedes its parent. On return, `merge_index` holds the
#   smallest row index used.
# aggregate: cumulative aggregation (cumsum / cummin / cummax).
# aggregation: the aggregation name; 'mean' turns the cumulative sums into means.
# The split point maximizes the aggregated distance between the prefix and the
# suffix. A node's height is that distance plus the larger of its children's heights.
top_down_divide <- function(accumulator, distances, aggregate, aggregation, indices_range) {
  split_count <- length(indices_range)
  stopifnot(split_count > 1)
  # Reverse the columns so the cumulative aggregation runs over suffixes of the range.
  effective_distances <- distances[indices_range, rev(indices_range)]
  # Cumulate down the rows, then along the columns (per row).
  effective_distances <- apply(effective_distances, 2, aggregate)
  effective_distances <- t(apply(effective_distances, 1, aggregate))
  # Undo the reversal: entry [i, j] aggregates rows 1..i against columns j..split_count.
  effective_distances <- effective_distances[, split_count:1]
  # Entry [i, i + 1]: prefix 1..i against suffix (i + 1)..split_count.
  prefix_sizes <- seq_len(split_count - 1)
  candidate_distances <- effective_distances[cbind(prefix_sizes, prefix_sizes + 1)]
  if (aggregation == 'mean') {
    # Turn block sums into block means (prefix size i, suffix size split_count - i).
    candidate_distances <- candidate_distances / seq_along(candidate_distances)
    candidate_distances <- candidate_distances / rev(seq_along(candidate_distances))
  }
  split_position <- which.max(candidate_distances)
  split_index <- split_position + min(indices_range) - 1
  low_range <- min(indices_range):split_index
  high_range <- (split_index + 1):max(indices_range)
  stopifnot(length(low_range) < split_count)
  stopifnot(length(high_range) < split_count)
  merge_index <- accumulator$merge_index
  if (length(low_range) == 1) {
    low_index <- -min(low_range)
    low_height <- 0
  } else {
    low_index <- accumulator$merge_index - 1
    accumulator$merge_index <- low_index
    accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation, low_range)
    low_height <- accumulator$height[low_index]
  }
  if (length(high_range) == 1) {
    high_index <- -min(high_range)
    high_height <- 0
  } else {
    high_index <- accumulator$merge_index - 1
    accumulator$merge_index <- high_index
    accumulator <- top_down_divide(accumulator, distances, aggregate, aggregation, high_range)
    high_height <- accumulator$height[high_index]
  }
  accumulator$height[merge_index] <- candidate_distances[split_position] +
    max(low_height, high_height)
  accumulator$merge[merge_index, ] <- sort(c(low_index, high_index))
  accumulator
}
#' Enhanced version of `dist`.
#'
#' Adds correlation-based distances on top of `dist`. For `pearson`, `kendall`
#' and `spearman`, the result is the matrix `2 - cor(t(data), method)`, tagged
#' with a `method` attribute. Any other method is handed to `dist` unchanged.
#' For `pearson`, a function named `tgs_cor` is used instead of `cor` when one
#' is visible (presumably from the tgstat package; not verified here).
#'
#' @param data A matrix whose rows are the objects to compare.
#' @param method A correlation method name or any `dist` method.
#' @return A distance matrix (correlation methods) or a `dist` object.
enhanced_dist <- function(data, method) {
  is_correlation <- method == 'pearson' || method == 'kendall' || method == 'spearman'
  if (!is_correlation) {
    return(dist(data, method))
  }
  rows_as_columns <- t(data)
  similarity <- if (method == 'pearson' && exists('tgs_cor')) {
    tgs_cor(rows_as_columns)
  } else {
    cor(rows_as_columns, method=method)
  }
  result <- 2 - similarity
  attr(result, 'method') <- method
  result
}
#' Plot a heatmap with values as close to the diagonal as possible.
#'
#' Given a matrix expressing the cross-similarity between two (possibly different) sets of
#' entities, this uses \code{slanted_orders} to move the high values close to the diagonal, then
#' computes an order-preserving clustering for visualizing the matrix with a dendrogram tree, and
#' passes all this to `pheatmap`.
#'
#' @param data A rectangular matrix
#' @param ... Additional flags to pass to `pheatmap`.
#' @param annotation_row Optional data frame describing each row.
#' @param annotation_col Optional data frame describing each column.
#' @param order The default for whether to order rows and columns.
#' @param order_rows Whether to reorder the rows.
#' @param order_cols Whether to reorder the columns.
#' @param same_order Whether to apply the same order to both rows and columns.
#' @param cluster The default for whether to cluster rows and columns.
#' @param cluster_rows Whether to cluster the rows, or an `hclust` object passed to `pheatmap`
#'   as-is (specify `order_rows=F` if giving an `hclust`).
#' @param cluster_cols Whether to cluster the columns, or an `hclust` object passed to `pheatmap`
#'   as-is (specify `order_cols=F` if giving an `hclust`).
#' @param distance_function The function for computing distance matrices (by default, `enhanced_dist`).
#' @param clustering_distance The default method for computing distances (by default, `pearson`).
#' @param clustering_distance_rows The method for computing distances between rows.
#' @param clustering_distance_cols The method for computing distances between columns.
#' @param clustering_method The default method to use for clustering the ordered data (by default, `agglomerative`).
#' @param clustering_method_rows The method to use for clustering the ordered rows.
#' @param clustering_method_cols The method to use for clustering the ordered columns.
#' @param clustering_aggregation The default aggregation method of cluster distances (by default, `mean`).
#' @param clustering_aggregation_rows How to aggregate distances of clusters of rows.
#' @param clustering_aggregation_cols How to aggregate distances of clusters of columns.
#' @return Whatever `pheatmap` returns.
sheatmap <- function(data, ...,
                     annotation_col=NULL,
                     annotation_row=NULL,
                     order=TRUE,
                     order_rows=NULL,
                     order_cols=NULL,
                     same_order=FALSE,
                     cluster=FALSE,
                     cluster_rows=NULL,
                     cluster_cols=NULL,
                     distance_function=enhanced_dist,
                     clustering_distance='pearson',
                     clustering_distance_rows=NULL,
                     clustering_distance_cols=NULL,
                     clustering_method='agglomerative',
                     clustering_method_rows=NULL,
                     clustering_method_cols=NULL,
                     clustering_aggregation='mean',
                     clustering_aggregation_rows=NULL,
                     clustering_aggregation_cols=NULL) {
  if (is.null(cluster_rows)) { cluster_rows <- cluster }
  if (is.null(cluster_cols)) { cluster_cols <- cluster }
  if (is.null(clustering_distance_rows)) { clustering_distance_rows <- clustering_distance }
  if (is.null(clustering_method_rows)) { clustering_method_rows <- clustering_method }
  if (is.null(clustering_aggregation_rows)) { clustering_aggregation_rows <- clustering_aggregation }
  if (is.null(clustering_distance_cols)) { clustering_distance_cols <- clustering_distance }
  if (is.null(clustering_method_cols)) { clustering_method_cols <- clustering_method }
  if (is.null(clustering_aggregation_cols)) { clustering_aggregation_cols <- clustering_aggregation }
  orders <- slanted_orders(data, order=order, order_rows=order_rows, order_cols=order_cols,
                           same_order=same_order)
  # drop=FALSE: a single row or column must stay a matrix for pheatmap.
  data <- data[orders$rows, orders$cols, drop=FALSE]
  if (!is.null(annotation_row)) {
    annotation_row <- reorder_frame(annotation_row, orders$rows)
  }
  if (!is.null(annotation_col)) {
    annotation_col <- reorder_frame(annotation_col, orders$cols)
  }
  # A caller-supplied `hclust` object is forwarded to pheatmap untouched; only a
  # truthy flag triggers computing an order-preserving clustering here.
  if (!inherits(cluster_rows, 'hclust') && cluster_rows) {
    rows_distances <- distance_function(data, method=clustering_distance_rows)
    cluster_rows <- oclust(rows_distances,
                           method=clustering_method_rows,
                           aggregation=clustering_aggregation_rows)
  }
  if (!inherits(cluster_cols, 'hclust') && cluster_cols) {
    cols_distances <- distance_function(t(data), method=clustering_distance_cols)
    cluster_cols <- oclust(cols_distances,
                           method=clustering_method_cols,
                           aggregation=clustering_aggregation_cols)
  }
  pheatmap::pheatmap(data, annotation_row=annotation_row, annotation_col=annotation_col,
                     cluster_rows=cluster_rows, cluster_cols=cluster_cols, ...)
}
#' Reorder the rows of a frame.
#'
#' Plain \code{data[order, ]} collapses a single-column data frame into a bare vector, so
#' this subsets with \code{drop=FALSE}. That keeps the result a data frame and preserves
#' every column's class and attributes (e.g. factor levels, which determine annotation
#' colors in \code{pheatmap}).
#'
#' @param data A data frame (or matrix) to reorder the rows of.
#' @param order An array containing indices permutation to apply to the rows.
#' @return The data frame with the new row orders; row names are permuted along with the rows.
reorder_frame <- function(data, order) {
  row_names <- rownames(data)
  data <- data[order, , drop=FALSE]
  rownames(data) <- row_names[order]
  data
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment