Last active
February 22, 2025 00:05
-
-
Save slowkow/5797728 to your computer and use it in GitHub Desktop.
Plot a horizontal bar plot and the lower triangle of a heatmap aligned at the base of the bars
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
barplot_heatmap.py | |
Kamil Slowikowski | |
April 4, 2014 | |
This module has a function for creating a horizontal bar plot with an adjacent | |
heatmap rotated 45 degrees to show the lower triangle of a correlation | |
matrix comparing all pairs of bars. | |
References: | |
http://stackoverflow.com/questions/12848581/is-there-a-way-to-rotate-a-matplotlib-plot-by-45-degrees | |
http://stackoverflow.com/questions/2982929/plotting-results-of-hierarchical-clustering-ontop-of-a-matrix-of-data-in-python | |
""" | |
import matplotlib as mp | |
import numpy as np | |
import pandas as pd | |
import pylab as pl | |
import scipy.cluster.hierarchy as sch | |
import string | |
def main():
    """Build a random demo bar series and data matrix, then render the
    combined bar plot / heatmap triangle figure to ``barplot_heatmap.png``."""
    mp.rc("font", family="serif")
    n_samples, n_bars = 5, 10
    # Distinct single-letter names shared by the bars and the matrix columns.
    bar_names = np.random.choice(list(string.ascii_uppercase),
                                 n_bars, replace=False)
    heights = pd.Series(np.random.random(n_bars) * 5.0, index=bar_names)
    values = np.random.random((n_samples, n_bars))
    data = pd.DataFrame(values, columns=bar_names)
    barplot(heights, data, "barplot_heatmap.png")
def barplot(series, matrix, outfile,
            figsize=(6, 6), fontsize=10, title=None):
    """Create a bar plot and place the lower triangle of a heatmap directly
    adjacent so that the bases of the bars line up with the diagonal of the
    heatmap.

    Parameters
    ----------
    series : pandas.Series
        The bar heights, indexed by bar label.
    matrix : pandas.DataFrame
        A matrix where each column corresponds to a bar in the bar plot.
        Column labels are expected to match the index of ``series``, because
        the bars are reordered by the labels returned from the clustering.
    outfile : str
        Full path to the output file.
    figsize : (width, height)
        Figure size, passed to ``pylab.figure``.
    fontsize : float
        Font size of the y-axis (bar label) tick labels.
    title : str
        Optional title for the bar plot.
    """
    fig = pl.figure(figsize=figsize)
    # Left half: axes for the rotated heatmap triangle.
    ax = fig.add_subplot(121, frame_on=False, aspect=2.0)
    # Draw the triangle; `order` holds the column labels in clustered order.
    cax, order = heatmap_triangle(matrix, ax)
    # Negative wspace pulls the bar plot left so the bar bases meet the
    # diagonal of the triangle.
    fig.subplots_adjust(wspace=-0.12, hspace=0, left=0, right=1)
    # Right half: axes for the bar plot.
    ax = fig.add_subplot(122, frame_on=False)
    # Put gridlines beneath the bars.
    ax.set_axisbelow(True)
    # Order the bars by the clustering. `.ix` was removed in pandas 1.0;
    # `order` contains labels, so label-based `.loc` is the equivalent.
    series = series.loc[order]
    ax = series.plot(ax=ax, kind='barh', title=title, linewidth=0,
                     grid=False, color='grey')
    ax.tick_params(axis='y', which='major', labelsize=fontsize)
    # Pass the visibility flag positionally: the `b=` keyword was renamed to
    # `visible=` in Matplotlib 3.5 and later removed.
    ax.grid(True, which='major', axis='both', alpha=0.5)
    # One x tick per unit up to just past the tallest bar.
    ax.set_xticks(np.arange(0, round(series.max() + 1)))
    # Put the y-axis labels on the right, away from the triangle.
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position('right')
    # Hide tick marks on both axes (labels remain).
    ax.tick_params(length=0, axis='x')
    ax.tick_params(length=0, axis='y')
    ax.set_xlabel('')
    ax.set_ylabel('')
    fig.savefig(outfile, bbox_inches='tight')
def heatmap_triangle(dataframe, axes):
    """Create a heatmap of the lower triangle of a pairwise correlation
    matrix of all pairs of columns in the given dataframe. The heatmap
    triangle is rotated 45 degrees clockwise and drawn on the given axes.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Pearson correlations are computed between its columns.
    axes : matplotlib.axes.Axes
        Axes on which the triangle is drawn.

    Returns
    -------
    caxes : matplotlib.collections.QuadMesh
        The mesh returned by ``pcolormesh``.
    order : pandas.Index
        Column labels of ``dataframe`` in clustered (dendrogram leaf) order.
    """
    N = dataframe.shape[1]
    D = dataframe.corr(method='pearson')
    # UPGMA clustering, but other methods are also available.
    # NOTE(review): linkage() treats a square 2-D input as N observation
    # vectors (the rows of the correlation matrix), not as a precomputed
    # distance matrix -- confirm this is the intended clustering.
    Z = sch.linkage(D, method='average')
    R = sch.dendrogram(Z, no_plot=True)
    cluster_order = R['leaves']
    # Leaf ids are positions; `.ix` was removed in pandas 1.0, so use `.iloc`.
    D = D.iloc[cluster_order, cluster_order]
    # Keep the lower triangle and mask every cell strictly above the diagonal.
    # Masking by position (not by `== 0`) keeps genuine 0.0 correlations
    # below the diagonal visible.
    upper = np.triu(np.ones((N, N), dtype=bool), k=1)
    C = np.ma.masked_array(np.tril(D.values), mask=upper)
    # Set the diagonal (self-correlations) to zero.
    diag = np.arange(N)
    C[diag, diag] = 0
    # Corner coordinates of the (N + 1) x (N + 1) mesh, mapped by `t` onto a
    # 45-degree rotated (and scaled) layout.
    A = np.array([(y, x) for x in range(N, -1, -1) for y in range(N + 1)])
    t = np.array([[0.5, 1], [0.5, -1]])
    A = np.dot(A, t)
    # -1.0 correlation is blue, 0.0 is white, 1.0 is red.
    cmap = pl.cm.RdBu_r
    norm = mp.colors.BoundaryNorm(np.linspace(-1, 1, 14), cmap.N)
    # This MUST be before the call to pcolormesh() to align properly.
    axes.set_xticks([])
    axes.set_yticks([])
    X = A[:, 1].reshape(N + 1, N + 1)
    Y = A[:, 0].reshape(N + 1, N + 1)
    # Draw on the given axes explicitly rather than pyplot's current axes.
    caxes = axes.pcolormesh(X, Y, np.flipud(C), cmap=cmap, norm=norm)
    # Clip at x = 0: the strictly-lower cells lie at x <= 0, so this hides
    # the empty upper half and halves the zeroed diagonal cells.
    axes.set_xlim(right=0)
    # Add a colorbar below the heatmap triangle.
    cb = pl.colorbar(caxes, ax=axes, orientation='horizontal', shrink=0.5825,
                     fraction=0.05, pad=-0.035, ticks=np.linspace(-1, 1, 5),
                     use_gridspec=True)
    # Raw string: same text, but no invalid "\m" / "\ " escape warnings.
    cb.set_label(r"$\mathrm{Pearson's}\ r$")
    return caxes, D.index
if __name__ == '__main__':
    # Render the demo figure only when executed as a script, not on import.
    main()
@slowkow Thanks for this function! Below is a shorter adaptation of it that supports some simple rotations. I hope it is useful for additional use cases.
If possible, I also have a question about whitespace. The single-rotation case with the transformation matrix
`t = [[.5, 1], [.5, -1]]`
expands the triangle and fills more of the whitespace by cutting the square in half (see the last panel, and also your example above). Could you list other matrices that achieve the same effect for the other orientations? That would be very useful for several of the panels below.
import matplotlib.pyplot as plt
import numpy as np
def heatmap_triangle(C, axes, rotation=None, show_cbar=False):
    """Draw a square matrix as a heatmap on ``axes``, optionally mapping the
    mesh corners through a 2x2 matrix so the (masked) triangle appears
    rotated.

    Parameters
    ----------
    C : array-like, shape (N, N)
        Values to draw; masked entries (e.g. from ``np.ma.masked_array``)
        are left blank.
    axes : matplotlib.axes.Axes
        Axes to draw on.
    rotation : 2x2 array-like, optional
        Matrix applied to the (y, x) corner coordinates of the mesh.
        ``None`` draws the untransformed grid.
    show_cbar : bool
        If True, add a horizontal colorbar below the heatmap.

    Returns
    -------
    matplotlib.collections.QuadMesh
        The mesh returned by ``pcolormesh``.
    """
    N = len(C)
    # (y, x) corner coordinates of the (N + 1) x (N + 1) mesh.
    A = np.array([(y, x) for x in range(N, -1, -1) for y in range(N + 1)])
    if rotation is not None:
        A = np.dot(A, np.array(rotation))
    cmap = plt.cm.coolwarm
    axes.set_xticks([])
    axes.set_yticks([])
    X = A[:, 1].reshape(N + 1, N + 1)
    Y = A[:, 0].reshape(N + 1, N + 1)
    # Draw on the given axes explicitly rather than pyplot's current axes.
    caxes = axes.pcolormesh(X, Y, np.flipud(C), cmap=cmap)
    # Clip at x = 0 to drop the empty half of a transformed layout. Skipped
    # when untransformed: that grid spans 0 <= x <= N, so clipping would
    # leave nothing visible.
    if rotation is not None:
        axes.set_xlim(right=0)
    if show_cbar:
        cb = plt.colorbar(caxes, ax=axes, orientation='horizontal',
                          shrink=0.2825, fraction=0.08, pad=-4.535,
                          ticks=np.linspace(-1, 1, 5), use_gridspec=True)
        cb.set_label("weight")
    return caxes
np.random.seed(500)
# Random 10x10 matrix; keep the lower triangle and mask the zeroed upper part.
X = np.random.rand(10, 10)
X = np.tril(X)
X = np.ma.masked_array(X, X == 0)

# Transformation matrices passed to heatmap_triangle as `rotation`.
r1 = [[-.5, -.25], [.5, -.25]]
# 315 degrees (expanded)
r2 = [[.5, 1], [.5, -1]]
# 45 degrees (alternative)
r3 = [[.5, -.25], [-.5, -.25]]

# Reference panel: the untransformed lower triangle, row 0 at the top.
# Masked cells are left blank by pcolormesh.
fig = plt.figure(figsize=[3, 3])
ax = plt.subplot()
mesh = ax.pcolormesh(X, cmap=plt.cm.coolwarm)
ax.invert_yaxis()
fig.colorbar(mesh, ax=ax)

# First figure: four orientations side by side.
n_rotations = 4
fig = plt.figure(figsize=[20, 6])

# 45 degrees
ax = plt.subplot(1, n_rotations, 1, frameon=False)
x = np.flip(np.flip(X, axis=1), axis=0)
heatmap_triangle(x, ax, r1)

# 45 degrees (alternate rotation)
ax = plt.subplot(1, n_rotations, 2, frameon=False)
x = np.rot90(np.fliplr(X), k=3)
heatmap_triangle(x, ax, r3)

# 135 degrees
ax = plt.subplot(1, n_rotations, 3, frameon=False)
x = np.rot90(np.flip(np.flip(X, axis=1), axis=0))
heatmap_triangle(x, ax, r1)

# 225 degrees (four quarter-turns of X leave its orientation unchanged)
ax = plt.subplot(1, n_rotations, 4, frameon=False)
x = np.ma.masked_array(X, X == 0)
heatmap_triangle(x, ax, r1)
plt.show()

# Second figure: 315-degree variants.
fig = plt.figure(figsize=[10, 3])

# 315 degrees (five quarter-turns reduce to one)
ax = plt.subplot(1, 2, 1, frameon=False)
x = np.rot90(X)
x = np.ma.masked_array(x, x == 0)
heatmap_triangle(x, ax, r1)

# 315 degrees (expanded)
ax = plt.subplot(1, 2, 2, frameon=False)
x = np.rot90(np.fliplr(X), k=3)
heatmap_triangle(x, ax, r2)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@MdUmar-tech Sorry, but I don't understand your goals. Please consider creating a new post on Stack Overflow, where hundreds of people will see it and someone might have time to help you. First, I would recommend that you try to make the figures yourself. Then, once you have something similar to your desired output, share with us:
If you can give an example of a figure that you like, then there's a better chance that someone might be able to help you.