Last active
December 2, 2020 21:55
-
-
Save Kautenja/f9d6fd3d1dee631200bc11b8a46a76b7 to your computer and use it in GitHub Desktop.
A quick method to generate a matplotlib correlation heatmap from a pandas.DataFrame.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
def heatmap(data,
            size_scale: int = 100,
            figsize: tuple = (15, 12),
            x_rotation: int = 270,
            cmap=mpl.cm.RdBu,
            ):
    """
    Build a heatmap of the pairwise Pearson's correlation of a data frame.

    Each cell of the correlation matrix is drawn as a square marker whose
    area is proportional to the absolute correlation and whose colour is
    ``cmap`` sampled at the correlation rescaled from [-1, 1] to [0, 1].
    Rows and columns are ordered by sorted column label.

    Args:
        data: the dataframe to plot the cross-correlation of. It is copied,
            never modified. Datetime / timedelta columns are converted to
            their integer (nanosecond) value, with NaT mapped to NaN; any
            other non-numeric, non-boolean columns are ignored.
        size_scale: the scaling parameter for box sizes (the marker area of
            a box whose correlation is +/-1)
        figsize: the size of the figure
        x_rotation: the rotation, in degrees, of the x-axis tick labels
        cmap: the colormap used for the boxes and the colour bar

    Returns:
        None

    Notes:
        based on -
        https://towardsdatascience.com/better-heatmaps-and-correlation-matrix-plots-in-python-41445d0f2bec

    """
    # copy the data before mutating it
    data = data.copy()
    # change datetimes and timedeltas to integers; Pearson's r is invariant
    # to the (positive) scale of the unit. NaT.value is the int64 sentinel
    # (a huge negative number), so mask it back to NaN instead of letting it
    # dominate the correlation.
    time_dtypes = [np.datetime64, np.timedelta64, 'datetimetz']
    for column in data.select_dtypes(include=time_dtypes):
        values = data[column]
        data[column] = values.apply(lambda x: x.value).where(values.notna())
    # DataFrame.corr raises on non-numeric columns in pandas >= 2.0 (older
    # versions silently dropped them), so keep only numeric / boolean columns
    data = data.select_dtypes(include=[np.number, 'bool'])
    # calculate the correlation matrix
    corr = data.corr()
    labels = list(corr.columns)
    # flatten the matrix into one (x, y, value) triple per cell, column by
    # column as pd.melt would. This is done explicitly because
    # reset_index() + melt(id_vars='index') breaks when the columns axis is
    # named (reset_index then emits that name instead of 'index').
    x = pd.Series([row for _ in labels for row in labels], dtype=object)
    y = pd.Series([col for col in labels for _ in labels], dtype=object)
    value = pd.Series(corr.to_numpy(dtype=float).ravel(order='F'))
    # the size is the absolute value (correlation is on [-1, 1]); the colour
    # samples the colormap at the correlation rescaled to [0, 1]
    size = value.abs()
    norm = (value + 1) / 2
    fig, ax = plt.subplots(figsize=figsize)
    # Mapping from (sorted) column names to integer coordinates; the matrix
    # is square, so both axes share the same labels
    grid_labels = sorted(set(labels))
    label_to_num = {label: i for i, label in enumerate(grid_labels)}
    ticks = list(range(len(grid_labels)))
    ax.scatter(
        x=x.map(label_to_num),     # Use mapping for x
        y=y.map(label_to_num),     # Use mapping for y
        s=size * size_scale,       # Vector of square sizes, proportional to size parameter
        marker='s',                # Use square as scatterplot marker
        c=cmap(norm.to_numpy()),   # one RGBA row per box (NaN -> the colormap's "bad" colour)
    )
    # Show column labels on the axes
    ax.set_xticks(ticks)
    ax.set_xticklabels(grid_labels, rotation=x_rotation, horizontalalignment='right')
    ax.set_yticks(ticks)
    ax.set_yticklabels(grid_labels)
    # draw grid lines on the minor ticks (between boxes) rather than on the
    # major ticks (through box centres)
    ax.grid(False, 'major')
    ax.grid(True, 'minor')
    ax.set_xticks([t + 0.5 for t in ticks], minor=True)
    ax.set_yticks([t + 0.5 for t in ticks], minor=True)
    # move the starting point of the x and y axis forward to remove extra
    # spacing from shifting grid points; using the label count (rather than
    # max() of the coordinates) also copes with an empty correlation matrix
    upper = len(grid_labels) - 0.5
    ax.set_xlim([-0.5, upper])
    ax.set_ylim([-0.5, upper])
    # add a color bar legend to the plot. ax is passed explicitly because a
    # bare ScalarMappable belongs to no Axes, and newer matplotlib releases
    # no longer guess one (deprecated in 3.6).
    bar = fig.colorbar(mpl.cm.ScalarMappable(cmap=cmap), ax=ax, ticks=[0, 0.25, 0.5, 0.75, 1])
    bar.outline.set_edgecolor('grey')
    bar.outline.set_linewidth(1)
    # the bar is in the rescaled [0, 1] units of `norm`; relabel it in
    # correlation units
    bar.ax.set_yticklabels(['-1', '-0.5', '0', '0.5', '1'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment