Skip to content

Instantly share code, notes, and snippets.

@vikramsoni2
Last active March 2, 2022 15:22
Show Gist options
  • Save vikramsoni2/9bea4dddd7af2af52538fc832db521a3 to your computer and use it in GitHub Desktop.
Save vikramsoni2/9bea4dddd7af2af52538fc832db521a3 to your computer and use it in GitHub Desktop.
parallel coordinates plot using matplotlib
import numpy as np
import pandas as pd
from matplotlib.path import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
from typing import Union, List, Tuple
def _pc_axis_limits(ys: np.ndarray,
                    reverse: List[int]) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-column ``(ymins, ymaxs)`` axis limits for a float matrix.

    Each column's data range is padded by 5% below and above. A column whose
    values are all equal gets a fixed padding of 0.5 instead, so its axis has
    a non-zero height and the later rescaling does not divide by zero. For
    every index in ``reverse`` the two limits are swapped, which flips that
    axis upside down (listing an index twice swaps it back).
    """
    ymins = ys.min(axis=0)
    ymaxs = ys.max(axis=0)
    span = ymaxs - ymins
    pad = np.where(span > 0, span * 0.05, 0.5)
    ymins = ymins - pad
    ymaxs = ymaxs + pad
    for idx in reverse:
        ymins[idx], ymaxs[idx] = ymaxs[idx], ymins[idx]
    return ymins, ymaxs


def _pc_host_coords(ys: np.ndarray,
                    ymins: np.ndarray,
                    ymaxs: np.ndarray) -> np.ndarray:
    """Map every column of ``ys`` onto the y scale of column 0.

    All curves are drawn on the host (first) axis, so each other column is
    linearly rescaled from its own ``[ymin, ymax]`` range (possibly reversed)
    onto the host's range. Column 0 is copied unchanged.
    """
    dys = ymaxs - ymins
    zs = (ys - ymins) / dys * dys[0] + ymins[0]
    zs[:, 0] = ys[:, 0]
    return zs


def parallel_coordinates(df: pd.DataFrame,
                         target: str,
                         fillna: float = -1,
                         reverse: Union[List[int], None] = None,
                         figsize: Tuple[int, int] = (15, 8)
                         ) -> None:
    """Draw a parallel coordinates plot of the numeric columns of ``df``.

    A parallel coordinates plot compares individual observations (rows)
    across a set of numeric variables. Each vertical bar represents one
    variable and has its own scale (the units may even differ). Each row is
    drawn as a smooth curve through its values on every axis, coloured by
    its class in the ``target`` column.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the features and the target column. Every
        numeric column becomes an axis (including ``target`` itself if it
        is numeric). The target column can be binary or multiclass; each
        class is drawn in a different color.
    target : str
        Name of the target (class) column in ``df``. Rows whose target is
        missing are drawn in light gray and listed as ``'NaN'`` in the legend.
    fillna : float, optional
        Value used to replace nulls in the numeric columns, by default -1.
    reverse : Union[List[int], None], optional
        Indexes (starting at 0, among the numeric columns) of axes to
        reverse, so that the maximum is at the bottom and the minimum at
        the top. This can reduce line crossings so the plot looks less
        wavy. By default None (no axis is reversed).
    figsize : Tuple[int, int], optional
        Figure size for the matplotlib figure, by default (15, 8).

    Raises
    ------
    KeyError
        If ``target`` is not a column of ``df``.
    ValueError
        If ``df`` has no rows or fewer than two numeric columns.

    Examples
    --------
    >>> import pandas as pd
    >>> from sklearn.datasets import load_iris
    >>> from parallel_coordinates import parallel_coordinates
    >>> data = load_iris()
    >>> df = pd.DataFrame(data.data, columns=data.feature_names)
    >>> df['species'] = [data.target_names[i] for i in data.target]
    >>> parallel_coordinates(df, 'species', reverse=[1])
    """
    # Category codes index the *sorted* categories, so the legend labels must
    # come from the same categorical; df[target].unique() lists classes in
    # order of appearance and would mislabel the colors.
    target_cat = df[target].astype('category')
    class_names = list(target_cat.cat.categories)
    class_codes = target_cat.cat.codes.tolist()  # -1 marks a missing target

    ynames = df.select_dtypes(include='number').columns
    if len(df) == 0:
        raise ValueError("df has no rows to plot")
    if len(ynames) < 2:
        raise ValueError(
            "parallel coordinates need at least two numeric columns, "
            f"got {len(ynames)}")

    # Work in float so integer columns can be padded and rescaled exactly.
    ys = df[ynames].fillna(fillna).to_numpy(dtype=float)
    n_axes = ys.shape[1]

    ymins, ymaxs = _pc_axis_limits(ys, [] if reverse is None else list(reverse))
    zs = _pc_host_coords(ys, ymins, ymaxs)

    # One twin y-axis per extra column, each placed at its x position.
    _, host = plt.subplots(figsize=figsize)
    axes = [host] + [host.twinx() for _ in range(n_axes - 1)]
    for i, ax in enumerate(axes):
        ax.set_ylim(ymins[i], ymaxs[i])
        ax.spines['top'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        if ax is not host:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_ticks_position('right')
            ax.spines["right"].set_position(("axes", i / (n_axes - 1)))

    host.set_xlim(0, n_axes - 1)
    host.set_xticks(range(n_axes))
    host.set_xticklabels(ynames, fontsize=14)
    host.tick_params(axis='x', which='major', pad=7)
    host.spines['right'].set_visible(False)
    host.xaxis.tick_top()
    host.set_title('Parallel Coordinates Plot', fontsize=18, pad=12)

    colors = plt.cm.Set2.colors
    # Axis i sits at x = i. Repeating each value three times (minus the first
    # and last copies) gives the control points of one cubic Bezier segment
    # per gap between axes, at x steps of 1/3.
    xs = np.linspace(0, n_axes - 1, n_axes * 3 - 2)
    codes = [Path.MOVETO] + [Path.CURVE4] * (len(xs) - 1)

    handles = {}  # class code -> one representative patch for the legend
    for row, code in zip(zs, class_codes):
        verts = np.column_stack([xs, np.repeat(row, 3)[1:-1]])
        # Cycle the palette so more classes than colors do not fail.
        color = 'lightgray' if code < 0 else colors[code % len(colors)]
        patch = patches.PathPatch(Path(verts, codes), facecolor='none',
                                  lw=2, alpha=0.7, edgecolor=color)
        handles[code] = patch
        host.add_patch(patch)

    present = [c for c in range(len(class_names)) if c in handles]
    legend_handles = [handles[c] for c in present]
    legend_labels = [class_names[c] for c in present]
    if -1 in handles:
        legend_handles.append(handles[-1])
        legend_labels.append('NaN')

    host.legend(legend_handles, legend_labels,
                loc='lower center', bbox_to_anchor=(0.5, -0.18),
                ncol=max(1, len(legend_labels)), fancybox=True, shadow=True)
    plt.tight_layout()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment