Skip to content

Instantly share code, notes, and snippets.

@Kautenja
Last active December 2, 2020 21:55
Show Gist options
  • Save Kautenja/f9d6fd3d1dee631200bc11b8a46a76b7 to your computer and use it in GitHub Desktop.
Save Kautenja/f9d6fd3d1dee631200bc11b8a46a76b7 to your computer and use it in GitHub Desktop.
A quick method to generate a matplotlib correlation heatmap from a pandas.DataFrame
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def heatmap(data,
            size_scale: int = 100,
            figsize: tuple = (15, 12),
            x_rotation: int = 270,
            cmap=mpl.cm.RdBu,
            ):
    """
    Build a heatmap of the Pearson correlation between the columns of a data frame.

    Every pair of columns is drawn as a square marker whose area is
    proportional to the absolute correlation and whose color encodes the
    signed correlation on [-1, 1]. Datetime and timedelta columns are
    converted to integer nanoseconds before correlating (missing values
    become NaN); any other non-numeric, non-boolean columns are ignored.

    Args:
        data: the dataframe to plot the cross-correlation of (not modified)
        size_scale: the scaling parameter for box sizes
        figsize: the size of the figure
        x_rotation: rotation, in degrees, of the x-axis tick labels
        cmap: the colormap (a matplotlib Colormap or a registered colormap
            name) used to color the boxes over the range [-1, 1]

    Returns:
        None

    Raises:
        ValueError: if `data` has no numeric, boolean, datetime, or timedelta
            columns to correlate

    Notes:
        based on -
        https://towardsdatascience.com/better-heatmaps-and-correlation-matrix-plots-in-python-41445d0f2bec
    """
    corr = _numeric_frame(data).corr()
    if corr.empty:
        raise ValueError(
            'heatmap requires at least one numeric, boolean, datetime, '
            'or timedelta column'
        )
    # The correlation matrix is square, so both axes share one sorted order.
    labels = sorted(corr.columns.unique())
    position = {label: i for i, label in enumerate(labels)}
    x, y, value = _long_form(corr, position)
    n = len(labels)

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.scatter(
        x=x,
        y=y,
        # the size is the absolute value (correlation is on [-1, 1])
        s=np.abs(value) * size_scale,
        marker='s',  # use squares as the scatterplot marker
        # map the signed correlation through the colormap over a fixed
        # [-1, 1] range so colors are comparable between plots
        c=value,
        cmap=cmap,
        vmin=-1,
        vmax=1,
    )
    # Show column labels on the axes
    ticks = range(n)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, rotation=x_rotation, horizontalalignment='right')
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels)
    # draw grid lines on minor ticks between boxes instead of through their centers
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ticks], minor=True)
    ax.set_yticks([t + 0.5 for t in ticks], minor=True)
    # trim the extra half-cell of padding on each side of the grid
    ax.set_xlim([-0.5, n - 0.5])
    ax.set_ylim([-0.5, n - 0.5])
    # Add a color bar legend driven by the scatter's own colormap and norm.
    # Passing ax= lets matplotlib know which axes to take space from.
    bar = fig.colorbar(im, ax=ax, ticks=[-1, -0.5, 0, 0.5, 1])
    bar.outline.set_edgecolor('grey')
    bar.outline.set_linewidth(1)
    bar.ax.set_yticklabels(['-1', '-0.5', '0', '0.5', '1'])


def _numeric_frame(data):
    """Return a copy of `data` with temporal columns as nanoseconds, keeping only numeric/bool columns."""
    data = data.copy()
    temporal = data.select_dtypes(include=[np.datetime64, 'datetimetz', np.timedelta64])
    for column in temporal:
        data[column] = data[column].apply(_to_nanoseconds)
    # Select explicitly rather than relying on DataFrame.corr's numeric_only
    # default, which changed from True to False in pandas 2.0.
    return data.select_dtypes(include=['number', 'bool'])


def _to_nanoseconds(value):
    """Return a Timestamp/Timedelta as integer nanoseconds, or NaN when missing."""
    # NaT.value is the int64 minimum sentinel; letting it through would
    # silently dominate the correlation, so missing values become NaN.
    return np.nan if pd.isna(value) else value.value


def _long_form(corr, position):
    """Flatten a correlation matrix into parallel (x, y, value) arrays of grid positions."""
    rows = np.array([position[label] for label in corr.index])
    cols = np.array([position[label] for label in corr.columns])
    # Row-major flattening: entry (i, j) lands at index i * len(cols) + j,
    # with the row label on x and the column label on y.
    x = np.repeat(rows, len(cols))
    y = np.tile(cols, len(rows))
    return x, y, corr.to_numpy(dtype=float).ravel()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment