Last active
May 18, 2022 23:02
-
-
Save TomasDrozdik/484c5cefbce5b6ab0a1ddd177060af2e to your computer and use it in GitHub Desktop.
Example of interval visualisation on a timeline with seaborn-like interface using matplotlib bars.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
import seaborn as sns | |
import matplotlib | |
import matplotlib.pyplot as plt | |
# Internal column names of the working frame built by ``overlapplot``.
_START = "start"
_END = "end"
_CAT = "cat"
_LABEL = "label"

# Each category occupies a horizontal band of this height; the bar inside it
# is shortened by ``_ROW_GAP`` and centred so that it lines up with its tick.
_ROW_HEIGHT = 1
_ROW_GAP = _ROW_HEIGHT * 0.2


def _resolve_palette(palette, n_colors, default=None):
    """Return an indexable sequence of colors for ``n_colors`` labels.

    A string is looked up as a seaborn palette name, a missing/empty palette
    falls back to ``default`` (``None`` lets seaborn pick), and any other
    sequence is returned untouched.
    """
    if isinstance(palette, str):
        return sns.color_palette(palette, n_colors=n_colors)
    if palette is None or len(palette) == 0:
        return sns.color_palette(default, n_colors=n_colors)
    return palette


def _overlap_segments(intervals):
    """Split intervals into per-category segments labelled with overlap count.

    ``intervals`` must have the columns ``start``, ``end`` and ``cat``.  The
    result has the same three columns plus ``label``, holding the number of
    input intervals that cover each segment; segments covered by no interval
    are dropped.
    """
    # One row per interval boundary: +1 when an interval opens, -1 when it closes.
    events = intervals.melt(
        id_vars=[_CAT],
        value_vars=[_START, _END],
        var_name="type",
        value_name="time",
    )
    events["delta"] = np.where(events["type"] == _START, 1, -1)

    # Sort once so the running count and the "next boundary" lookup below
    # are guaranteed to see the events in the same order.
    events = events.sort_values(by=[_CAT, "time"])
    per_category = events.groupby(_CAT)

    segments = pd.DataFrame(
        {
            _CAT: events[_CAT],
            _START: events["time"],
            # A segment ends where the next boundary of the same category is.
            _END: per_category["time"].shift(-1),
            _LABEL: per_category["delta"].cumsum(),
        }
    )
    # The last boundary of every category has no successor (NaN end).
    segments = segments.dropna()
    return segments[segments[_LABEL] != 0]


def overlapplot(
    xstart, xend, ycategory, data=None, hue=None, palette=None, ax=None, **kwargs
):
    """Draw intervals as horizontal bars, one row per category.

    Without ``hue`` the intervals of each category are cut into segments
    colored by how many intervals overlap there.  With ``hue`` every interval
    is drawn as-is and colored by its hue value.

    Parameters
    ----------
    xstart, xend, ycategory
        Column names in ``data``; when ``data`` is None, the values themselves
        (interval starts, interval ends, category of each interval).
    data : pandas.DataFrame, optional
        Frame holding the columns named above.
    hue
        Column name in ``data`` (or the values themselves when ``data`` is
        None) used to color the intervals.  ``None`` selects overlap coloring.
    palette
        Seaborn palette name or a sequence of colors indexed by label
        position.  Defaults to seaborn's current palette with ``hue`` and to
        ``"rocket"`` for overlap counts.
    ax : matplotlib Axes, optional
        Axes to draw on.  When omitted, ``plt.gca(**kwargs)`` is used.
    **kwargs
        Forwarded to ``plt.gca`` when ``ax`` is not given.
        NOTE(review): recent Matplotlib releases appear to reject keyword
        arguments to ``gca`` -- confirm against the version in use.

    Returns
    -------
    The Axes the bars were drawn on.

    Raises
    ------
    IndexError
        If a sequence ``palette`` has fewer colors than there are labels.
    """
    from_frame = data is not None

    # Load data into a frame with fixed column names.
    intervals = pd.DataFrame()
    intervals[_START] = data[xstart] if from_frame else xstart
    intervals[_END] = data[xend] if from_frame else xend
    intervals[_CAT] = data[ycategory] if from_frame else ycategory
    if hue is not None:
        intervals[_LABEL] = data[hue] if from_frame else hue
    # The incoming index may contain duplicates (e.g. after pd.concat).
    intervals = intervals.reset_index(drop=True)

    # Taken before any filtering so every category keeps its row in the plot.
    categories = intervals[_CAT].unique()

    # Figure out coloring and labeling.
    if hue is None:
        intervals = _overlap_segments(intervals)
        # Sorted so the sequential default palette follows the overlap count.
        levels = sorted(intervals[_LABEL].unique())
        default_palette = "rocket"
    else:
        levels = list(intervals[_LABEL].unique())
        default_palette = None
    colors = _resolve_palette(palette, len(levels), default_palette) if levels else []
    color_of = {level: colors[position] for position, level in enumerate(levels)}

    # Filter nonzero ranges.
    drawable = intervals[intervals[_START] < intervals[_END]]
    shown_labels = list(drawable[_LABEL].unique())
    if hue is None:
        shown_labels.sort()

    if ax is None:
        ax = plt.gca(**kwargs)

    for label in shown_labels:
        has_label = drawable[_LABEL] == label
        for row, category in enumerate(categories):
            chosen = drawable[has_label & (drawable[_CAT] == category)]
            if chosen.empty:
                continue
            # Horizontal bars require (x_start, x_length) tuples.
            xranges = list(zip(chosen[_START], chosen[_END] - chosen[_START]))
            ax.broken_barh(
                xranges,
                yrange=(row * _ROW_HEIGHT + _ROW_GAP / 2, _ROW_HEIGHT - _ROW_GAP),
                color=color_of[label],
            )

    # Create legend.
    legend_elements = [
        matplotlib.patches.Patch(facecolor=color_of[label], label=label)
        for label in shown_labels
    ]
    ax.legend(handles=legend_elements)

    # One tick per category, in the middle of its band.
    ax.set_yticks(
        [row * _ROW_HEIGHT + _ROW_HEIGHT / 2 for row, _ in enumerate(categories)]
    )
    ax.set_yticklabels(categories)
    return ax
# Column names shared by both demo frames.
xstart = "start"
xend = "end"
category = "pair"

# Simple example: two staggered series in one category
# (0-10 / 50-60 against 5-15 / 55-65), so each pair overlaps by 5 units.
df = pd.concat(
    [
        pd.DataFrame({
            xstart: np.arange(0, 100, 50),
            xend: np.arange(10, 110, 50),
            category: "A",
        }),
        pd.DataFrame({
            xstart: np.arange(5, 105, 50),
            xend: np.arange(15, 115, 50),
            category: "A",
        }),
    ],
    # Avoid duplicated index labels in the combined frame.
    ignore_index=True,
)
ax = overlapplot(xstart, xend, category, data=df)

# More overlapping example
n = 10
a = np.random.normal(100, 20, n)
b = np.random.binomial(200, 0.5, n)
df = pd.concat(
    [
        pd.DataFrame({
            xstart: a,
            xend: a + 40,
            category: "A",
        }, index=np.arange(0, n)),
        pd.DataFrame({
            xstart: b,
            xend: b + 10,
            category: "B",
        }, index=np.arange(0, n)),
    ],
    ignore_index=True,
)
df = df[df[xstart] < df[xend]]
# Start a fresh figure; otherwise this plot is drawn on top of the first one
# because overlapplot picks up the current axes.
plt.figure()
overlapplot(xstart, xend, category, df)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment