-
-
Save xigrug/7c675cd3739122da2e91a16bd1f67ab4 to your computer and use it in GitHub Desktop.
jointplot_w_hue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
jointplots with hue groupings.

Provides ``jointplot_w_hue``, a seaborn-style jointplot (scatter plus
marginal distributions) whose points and marginals are split by a
categorical ``hue`` column.

minimum working example
-----------------------
iris = sns.load_dataset("iris")
jointplot_w_hue(data=iris, x='sepal_length', y='sepal_width', hue='species')['fig']

changelog
---------
2018 Mar 5: added legends and colormap
2018 Feb 19: gist made
"""
# NOTE(review): the address below looks redacted by the web page this was
# copied from ("[email protected]" placeholder) -- restore the real one.
__author__ = "[email protected]"
__copyright__ = "Copyright 2018, github.com/ruxi"
__license__ = "MIT"
# Must be a string: the bare literal 0.0.1 is a SyntaxError.
__version__ = "0.0.1"
# update: Mar 5, 2018
# created: Feb 19, 2018
# desc: seaborn jointplot with 'hue'
# prepared for issue: https://github.com/mwaskom/seaborn/issues/365
import itertools

import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import seaborn as sns
# Import-time side effect: switch seaborn's global axes style to "darkgrid"
# for every plot made after this module is imported.
sns.set_style(style="darkgrid")
def jointplot_w_hue(data, x, y, hue=None, colormap=None,
                    figsize=None, fig=None, scatter_kws=None):
    """Draw a jointplot (scatter + marginal distributions) split by *hue*.

    Parameters
    ----------
    data : pandas.DataFrame
        Source table containing the columns named by *x*, *y* and *hue*.
    x, y : str
        Column names plotted on the horizontal and vertical axes.
    hue : str, optional
        Column whose distinct non-null values define the groups. When None,
        nothing is drawn and the string ``"use normal sns.jointplot"`` is
        returned.
    colormap : sequence of colors, optional
        Colors assigned to the hue groups in order of first appearance.
        Defaults to ``sns.color_palette()``. Colors are reused cyclically
        when there are more groups than colors.
    figsize : tuple, optional
        Passed to ``plt.figure`` when *fig* is not given. Default ``(5, 5)``.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into. A new one is created when omitted.
    scatter_kws : dict, optional
        Forwarded as ``scatter_kws`` to ``sns.regplot`` for the main scatter.
        Default ``dict(alpha=0.4, lw=1)``.

    Returns
    -------
    dict or str
        ``{"fig": Figure, "gridspec": GridSpec}``, or the string
        ``"use normal sns.jointplot"`` when *hue* is None.
    """
    # Check hue before touching matplotlib, so no orphan figure is created
    # on the early-return path.
    if hue is None:
        return "use normal sns.jointplot"

    # defaults
    if colormap is None:
        colormap = sns.color_palette()
    if figsize is None:
        figsize = (5, 5)
    if fig is None:
        fig = plt.figure(figsize=figsize)
    if scatter_kws is None:
        scatter_kws = dict(alpha=0.4, lw=1)

    # derived variables
    # Drop missing hue values: NaN never equals itself, so a NaN "group"
    # would select zero rows and hand an empty series to distplot.
    hue_groups = data[hue].dropna().unique()
    # Cycle the palette so every group gets a color. Slicing the palette and
    # zipping silently dropped groups beyond its length, which then raised
    # KeyError in the plotting loops.
    colors = dict(zip(hue_groups, itertools.cycle(colormap)))
    subdata = {grp: data[data[hue] == grp] for grp in hue_groups}
    legend_mapping = [mpatches.Patch(color=colors[grp], label=grp)
                      for grp in hue_groups]

    # canvas setup: 2x2 grid, big scatter bottom-left, marginals on top and
    # right, legend in the top-right cell.
    grid = gridspec.GridSpec(2, 2,
                             width_ratios=[4, 1],
                             height_ratios=[1, 4],
                             hspace=0, wspace=0)
    # Add axes to *fig* explicitly: plt.subplot targets the *current* figure,
    # which is not *fig* when the caller passes in a non-current figure.
    ax_main = fig.add_subplot(grid[1, 0])
    ax_xhist = fig.add_subplot(grid[0, 0], sharex=ax_main)
    # No sharey with ax_main here (it was commented out in the original), so
    # the y-marginal keeps its own y-limits.
    ax_yhist = fig.add_subplot(grid[1, 1])

    ## plotting
    # NOTE(review): sns.distplot is deprecated in newer seaborn releases
    # (histplot/kdeplot replace it) -- confirm against the installed version.
    # histplot x-axis
    for grp in hue_groups:
        sns.distplot(subdata[grp][x], color=colors[grp], ax=ax_xhist)
    # histplot y-axis
    for grp in hue_groups:
        sns.distplot(subdata[grp][y], color=colors[grp], ax=ax_yhist,
                     vertical=True)
    # main scatterplot
    # note: must be after the histplots else ax_yhist messes up
    for grp in hue_groups:
        sns.regplot(data=subdata[grp], fit_reg=False,
                    x=x, y=y, ax=ax_main, color=colors[grp],
                    scatter_kws=scatter_kws)

    # despine the marginals and hide their tick labels
    for marginal_ax in (ax_yhist, ax_xhist):
        sns.despine(ax=marginal_ax, bottom=False, top=True, left=False,
                    right=True, trim=False)
        plt.setp(marginal_ax.get_xticklabels(), visible=False)
        plt.setp(marginal_ax.get_yticklabels(), visible=False)

    # topright: legend cell
    ax_legend = fig.add_subplot(grid[0, 1])
    plt.setp(ax_legend.get_xticklabels(), visible=False)
    plt.setp(ax_legend.get_yticklabels(), visible=False)
    ax_legend.legend(handles=legend_mapping)

    # Close *this* figure (plt.close() would close whatever figure is
    # current, possibly an unrelated one of the caller's). Presumably this is
    # done so notebooks don't auto-render it; the caller gets it back below.
    plt.close(fig)
    return dict(fig=fig, gridspec=grid)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment