Last active
March 2, 2022 15:22
-
-
Save vikramsoni2/9bea4dddd7af2af52538fc832db521a3 to your computer and use it in GitHub Desktop.
parallel coordinates plot using matplotlib
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
from matplotlib.path import Path | |
import matplotlib.pyplot as plt | |
import matplotlib.patches as patches | |
import seaborn as sns | |
from typing import Union, List, Tuple | |
def parallel_coordinates(df: pd.DataFrame,
                         target: str,
                         fillna: float = -1,
                         reverse: Union[List[int], None] = None,
                         figsize: Tuple[int, int] = (15, 8)
                         ) -> None:
    """Draw a parallel coordinates plot of the numeric columns of ``df``.

    A parallel coordinates plot compares the features of several
    individual observations (series) on a set of numeric variables.
    Each vertical axis represents one variable and has its own scale
    (the units can even differ). Each row is drawn as a smooth Bezier
    curve crossing every axis, coloured by its class in ``target``.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing features and target column.
        The target column can be binary or multiclass.
        Every numeric column becomes one vertical axis; note that a
        numeric target column is therefore also drawn as an axis.
    target : str
        Name of the column whose classes determine the line colours
        (one colour and one legend entry per class). Rows whose target
        is missing are drawn in light grey and left out of the legend.
    fillna : float, optional
        Value used to replace missing values in the plotted columns,
        by default -1.
    reverse : Union[List[int], None], optional
        A list of column indexes (starting with 0, counted among the
        plotted numeric columns) whose y scale is reversed so that the
        max value appears at the bottom and the min at the top. Useful
        to reduce line crossings, by default None.
    figsize : Tuple[int, int], optional
        Figure size for the matplotlib figure, by default (15, 8).

    Raises
    ------
    ValueError
        If ``df`` has no numeric columns or no rows to plot.

    Examples
    --------
    >>> import pandas as pd
    >>> from sklearn.datasets import load_iris
    >>> from parallel_coordinates import parallel_coordinates
    >>> data = load_iris()
    >>> df = pd.DataFrame(data.data, columns=data.feature_names)
    >>> df['species'] = [data.target_names[i] for i in data.target]
    >>> parallel_coordinates(df, 'species', reverse=[1])
    """
    # Legend labels and colour codes must come from the same categorical:
    # ``unique()`` is in appearance order while ``cat.codes`` follow the
    # category order, so mixing the two mislabels the legend.
    target_cat = df[target].astype('category').cat.remove_unused_categories()
    class_names = [str(name) for name in target_cat.cat.categories]
    # Positional array: ``series[j]`` is label-based and breaks for frames
    # without a default RangeIndex. Missing targets get code -1.
    class_codes = target_cat.cat.codes.to_numpy()

    ynames = df.select_dtypes(include=np.number).columns
    if len(ynames) == 0:
        raise ValueError('df has no numeric columns to plot')
    n_cols = len(ynames)
    # Force float so the padding and rescaling below also work when every
    # plotted column is integer-typed.
    ys = df[ynames].fillna(fillna).to_numpy(dtype=float)
    if ys.shape[0] == 0:
        raise ValueError('df has no rows to plot')

    ymins = ys.min(axis=0)
    ymaxs = ys.max(axis=0)
    dys = ymaxs - ymins
    # A constant column has zero range; give it a unit span so it gets
    # visible padding and the rescaling below does not divide by zero.
    dys = np.where(dys > 0, dys, 1.0)
    ymins = ymins - dys * 0.05  # add 5% padding below and above
    ymaxs = ymaxs + dys * 0.05

    if reverse is None:
        reverse = []
    for idx in reverse:
        # Swapping the limits flips that axis (max at the bottom), which can
        # reduce line crossings.
        ymins[idx], ymaxs[idx] = ymaxs[idx], ymins[idx]
    dys = ymaxs - ymins

    # All curves are drawn on the first axis, so map every column into that
    # axis' data coordinates (sign of dys carries any reversal).
    zs = np.empty_like(ys)
    zs[:, 0] = ys[:, 0]
    zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

    fig, host = plt.subplots(figsize=figsize)
    axes = [host] + [host.twinx() for _ in range(n_cols - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins[i], ymaxs[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax is not host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            # Place the i-th twin axis at fraction i/(n_cols-1) of the width.
            ax.spines['right'].set_position(('axes', i / (n_cols - 1)))

    host.set_xlim(0, n_cols - 1)
    host.set_xticks(range(n_cols))
    host.set_xticklabels(ynames, fontsize=14)
    host.tick_params(axis='x', which='major', pad=7)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()
    host.set_title('Parallel Coordinates Plot', fontsize=18, pad=12)

    colors = plt.cm.Set2.colors
    legend_handles = [None] * len(class_names)
    # Each segment between neighbouring axes is a cubic Bezier whose control
    # points repeat the endpoint values, so curves meet every axis flat.
    # The x positions step by 1/3 from axis 0 to axis n_cols - 1; this must
    # use the column count, not the row count.
    xs = np.linspace(0, n_cols - 1, n_cols * 3 - 2, endpoint=True)
    path_codes = [Path.MOVETO] + [Path.CURVE4] * (len(xs) - 1)
    for j in range(ys.shape[0]):
        verts = list(zip(xs, np.repeat(zs[j, :], 3)[1:-1]))
        path = Path(verts, path_codes)
        cls = int(class_codes[j])
        if cls < 0:
            # Missing target value: neutral colour, no legend entry.
            edgecolor = 'lightgray'
        else:
            # Cycle the palette so more than len(colors) classes still work.
            edgecolor = colors[cls % len(colors)]
        patch = patches.PathPatch(path, facecolor='none', lw=2, alpha=0.7,
                                  edgecolor=edgecolor)
        if cls >= 0:
            legend_handles[cls] = patch
        host.add_patch(patch)

    if class_names:
        host.legend(legend_handles, class_names,
                    loc='lower center', bbox_to_anchor=(0.5, -0.18),
                    ncol=len(class_names), fancybox=True, shadow=True)
    plt.tight_layout()
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment