Last active
March 13, 2021 23:39
-
-
Save ruxi/ff0e9255d74a3c187667627214e1f5fa to your computer and use it in GitHub Desktop.
jointplot_w_hue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__author__ = "[email protected]"
__copyright__ = "Copyright 2020, 2018, https://gist.github.com/ruxi/ff0e9255d74a3c187667627214e1f5fa"
__license__ = "MIT"
__version__ = "0.0.2"
# update: June 13, 2020
# created: Feb 19, 2018
# desc: seaborn jointplot with 'hue'
# prepared for issue: https://github.com/mwaskom/seaborn/issues/365
# resolved (22 Aug 2020): https://github.com/mwaskom/seaborn/pull/2210
import seaborn as sns
import matplotlib.pyplot as plt
def plot_jointgrid_hue(
    data,
    x,
    y,
    hue,
    cmap=None,  # e.g. ['green', 'orange'], 'Set2', or {'setosa': 'green', ...}
    alphas: list = None,  # e.g. [0.2, 0.5]
    alpha=None,
    marker_map: list = None,  # e.g. ['x', '+']
    marker=None,
    map_plot_margin_x=None,  # None -> sns.distplot
    map_plot_margin_y=None,  # None -> sns.distplot
    map_plot_joint=sns.scatterplot,
    kw_jointgrid=None,  # None -> {}
    kw_margins=None,  # None -> dict(kde=True)
    kw_scatter=None,  # None -> {}
):
    """
    seaborn JointGrid with a ``hue`` grouping.

    Splits ``data`` with ``data.groupby(hue)``. For every group it draws the
    x / y marginal plots on ``grid.ax_marg_x`` / ``grid.ax_marg_y`` and the
    joint plot on ``grid.ax_joint``. It then adds a figure-level legend
    with one entry per group.

    parameters
    ----------
    data : pandas.DataFrame
    x, y : column names in ``data`` for the horizontal / vertical axes.
    hue : groupby key of ``data`` that defines the groups; also used as the
        legend title.
    cmap : colors per group. One of:
        - None: the current seaborn color cycle, repeated as needed so any
          number of groups gets a color;
        - str: a seaborn palette name (e.g. 'Set2');
        - a sequence of colors, indexed by group position;
        - a mapping {hue level: color}; levels missing from the mapping fall
          back to positional lookup ``cmap[i]``.
    alphas, marker_map : optional per-group list/tuple of alpha values /
        marker styles, indexed by group position. If not a list/tuple, the
        scalar ``alpha`` / ``marker`` is used for every group.
    alpha : alpha used for every group when ``alphas`` is not given
        (default 0.5).
    marker : marker used for every group when ``marker_map`` is not given
        (default "o").
    map_plot_margin_x, map_plot_margin_y : callables that draw the marginals.
        They are called as ``f(a=values, ax=axes, color=color, **kw_margins)``;
        the y margin additionally receives ``vertical=True``.
        Default: ``sns.distplot``. NOTE: distplot is deprecated in newer
        seaborn releases; pass a compatible callable if it is unavailable.
    map_plot_joint : callable that draws the joint plot. It is called as
        ``f(data=subset, x=subset[x], y=subset[y], ax=grid.ax_joint,
        marker=..., alpha=..., color=..., **kw_scatter)``.
        Default: ``sns.scatterplot``.
    kw_jointgrid : extra keyword arguments for ``sns.JointGrid``.
    kw_margins : extra keyword arguments for the marginal callables
        (default ``dict(kde=True)``).
    kw_scatter : extra keyword arguments for the joint callable. Keys that
        also appear among the per-group params (marker/alpha/color) override
        them.

    returns
    -------
    seaborn.axisgrid.JointGrid
        Its figure is closed with ``plt.close(grid.fig)`` before returning;
        access it via ``g.fig``.

    raises
    ------
    IndexError
        If a positional ``cmap`` / ``alphas`` / ``marker_map`` sequence is
        shorter than the number of hue groups.

    minimum working example
    -----------------------
    iris = sns.load_dataset("iris")
    g = plot_jointgrid_hue(data=iris, x = 'sepal_length', y = 'sepal_width', hue = 'species')
    g.fig

    changelog
    ---------
    2020 June 13: Returns JointGrid as a base instead of GridSpec.
                  Include the option to use different alphas and markers
                  for each hue group.
    2018 Mar 5: added legends and colormap
    2018 Feb 19: gist made
    """
    # Local imports keep this function self-contained: Mapping recognises
    # dict-like `cmap` values, Line2D builds legend proxy artists.
    from collections.abc import Mapping
    from matplotlib.lines import Line2D

    #+------------------+
    #| default mappings |
    #+------------------+
    groups = data.groupby(hue)
    n_groups = groups.ngroups
    if cmap is None:
        # Request exactly n_groups colors; seaborn cycles the palette, so more
        # groups than palette entries no longer raise IndexError.
        cmap = sns.color_palette(n_colors=n_groups)
    elif isinstance(cmap, str):
        # A palette name: indexing the string itself would yield characters.
        cmap = sns.color_palette(cmap, n_colors=n_groups)
    if marker is None:
        marker = "o"
    if alpha is None:
        alpha = 0.5
    if map_plot_margin_x is None:
        map_plot_margin_x = sns.distplot
    if map_plot_margin_y is None:
        map_plot_margin_y = sns.distplot
    # None defaults (instead of dict literals in the signature) avoid sharing
    # one mutable dict object across calls.
    kw_jointgrid = {} if kw_jointgrid is None else kw_jointgrid
    kw_margins = dict(kde=True) if kw_margins is None else kw_margins
    kw_scatter = {} if kw_scatter is None else kw_scatter

    #+------------------+
    #| initialize grid  |
    #+------------------+
    grid = sns.JointGrid(data=data, x=x, y=y, **kw_jointgrid)

    legend_handles = []
    for i, (k, subset) in enumerate(groups):
        if isinstance(cmap, Mapping) and k in cmap:
            color = cmap[k]
        else:
            color = cmap[i]
        mapped_params = dict(
            marker=marker_map[i] if isinstance(marker_map, (list, tuple)) else marker,
            alpha=alphas[i] if isinstance(alphas, (list, tuple)) else alpha,
            color=color,
        )
        map_plot_margin_x(a=subset[x], ax=grid.ax_marg_x, color=color, **kw_margins)
        map_plot_margin_y(
            a=subset[y], ax=grid.ax_marg_y, vertical=True, color=color, **kw_margins
        )
        # Merge rather than passing two ** dicts, so a key present in both
        # (e.g. alpha in kw_scatter) overrides instead of raising TypeError.
        map_plot_joint(
            data=subset,
            x=subset[x],
            y=subset[y],
            ax=grid.ax_joint,
            **{**mapped_params, **kw_scatter},
        )
        #+----------------+
        #| legend handles |
        #+----------------+
        # https://matplotlib.org/tutorials/intermediate/legend_guide.html
        # An unattached proxy artist: plt.plot([0], ...) would draw a real
        # point on whichever axes is current, altering the plot's limits.
        legend_handles.append(
            Line2D(
                [],
                [],
                linestyle="None",
                marker=mapped_params["marker"],
                color=color,
                label=str(k),
            )
        )

    #+-----------------+
    #| populate legend |
    #+-----------------+
    grid.fig.legend(title=hue, handles=legend_handles)
    # Close this grid's figure specifically; a bare plt.close() closes
    # whichever figure happens to be current.
    plt.close(grid.fig)
    return grid
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I get the following error when choosing a colormap myself:

    File "plot_tsne_growthphase_kde.py", line 37, in jointplot_w_hue
        active_colormap = colormap[0:len(hue_groups)]
    TypeError: unhashable type: 'slice'