Last active
January 18, 2024 09:04
-
-
Save YuanfengZhang/69205779257d34a1f6698a79339d086a to your computer and use it in GitHub Desktop.
Draw a heatmap for discrete values in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
A quick guide for beginners to draw a heatmap for discrete values. | |
""" | |
from collections import OrderedDict | |
from typing import Tuple | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from matplotlib.colors import LinearSegmentedColormap | |
from matplotlib.axes import Axes | |
from matplotlib.figure import Figure | |
def discrete_heatmap(df: pd.DataFrame,
                     x_column: str,
                     y_column: str,
                     value_column: str,
                     colormap: dict[str, str],
                     string_for_nan: str = 'NA') -> Tuple[Figure, Axes]:
    """Draw a heatmap whose cells are coloured by discrete (categorical) values.

    The long-format dataframe is pivoted into a ``y_column`` x ``x_column`` grid,
    every discrete value is translated into an integer code, the codes are drawn
    with one flat colour per value, and the original values are written into the
    cells as annotations.

    Args:
        df (pd.DataFrame): the long-format dataframe containing the data to plot.
            It should look like this:
            >>> df
            country  month  most_liked
            U.S.     Dec    Andy
            Russian  Dec    Sasha
            China    Dec    Quan
            ...      ...    ...
        x_column (str): the column shown on the x-axis of the heatmap.
        y_column (str): the column shown on the y-axis of the heatmap.
        value_column (str): the column whose values fill the heatmap cells.
        colormap (dict[str, str]): every unique discrete value mapped to the colour
            used for it. It should look like this:
            >>> colormap
            {'Andy': '#BBD844', 'Sasha': '#448CA2', 'Quan': '#E77FCA'}
            If ``string_for_nan`` is not a key of ``colormap``,
            ``{string_for_nan: '#DAD8D8'}`` is prepended as the default.
        string_for_nan (str): the label that replaces missing cells after pivoting.

    Returns:
        Tuple[Figure, Axes]: the created figure and the axes holding the heatmap.

    Raises:
        KeyError: if the pivoted data contains a value that has no colour in
            ``colormap``.
    """
    # Pivot
    try:
        _pivot: pd.DataFrame = df.pivot(index=y_column,
                                        columns=x_column,
                                        values=value_column)
    except ValueError:
        # pivot() rejects duplicated (index, column) pairs; keep the first entry.
        _pivot = df.pivot_table(index=y_column,
                                columns=x_column,
                                values=value_column,
                                aggfunc='first')
    _pivot = _pivot.fillna(string_for_nan)

    # Colour table: the NaN placeholder (whatever its label is) always gets a colour.
    _ordered_cmap: dict[str, str]
    if string_for_nan in colormap:
        _ordered_cmap = dict(colormap)
    else:
        _ordered_cmap = {string_for_nan: '#DAD8D8', **colormap}
    _index_dict: dict[str, int] = {_k: _i for _i, _k in enumerate(_ordered_cmap)}
    _n_colors: int = len(_ordered_cmap)

    # Fail early with a readable message instead of a bare KeyError from a lookup.
    _unknown = set(_pivot.to_numpy().ravel()).difference(_ordered_cmap)
    if _unknown:
        raise KeyError('values without a colour in colormap: '
                       f'{sorted(str(_v) for _v in _unknown)}')

    # seaborn needs numeric data for the colours; the labels go into the annotation.
    _codes: pd.DataFrame = _pivot.map(_index_dict.__getitem__)

    # Plot
    _fig: Figure
    _ax: Axes
    _fig, _ax = plt.subplots(figsize=(8, 8), dpi=200)
    sns.heatmap(data=_codes,
                cmap=LinearSegmentedColormap.from_list(name=value_column,
                                                       colors=list(_ordered_cmap.values()),
                                                       N=_n_colors),
                # Fixed limits centre every integer code in its own colour band,
                # even when some categories do not occur in the data.
                vmin=-0.5, vmax=_n_colors - 0.5,
                annot=_pivot,
                fmt='', linewidths=.5,
                ax=_ax)

    # One colorbar tick per category, placed in the middle of its colour band.
    _colorbar = _ax.collections[0].colorbar
    _colorbar.set_ticks(list(range(_n_colors)))
    _colorbar.set_ticklabels(list(_ordered_cmap.keys()))
    _ax.tick_params(axis='both', length=0)
    return _fig, _ax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment