Created
July 26, 2018 09:50
-
-
Save dkapitan/fcf45a97caaf48bc3d6be17b5f8b213c to your computer and use it in GitHub Desktop.
Multiple Seaborn plots in a grid
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://stackoverflow.com/questions/35042255/how-to-plot-multiple-seaborn-jointplot-in-subplot | |
import matplotlib.pyplot as plt | |
import matplotlib.gridspec as gridspec | |
import seaborn as sns | |
import numpy as np | |
class SeabornFig2Grid:
    """Move the axes of a seaborn grid into one slot of another figure.

    Seaborn's ``FacetGrid``, ``PairGrid`` and ``JointGrid`` each create
    their own figure.  This helper detaches the axes of such a grid,
    re-attaches them to ``fig`` inside ``subplot_spec`` and then closes the
    grid's original figure, so several grids can share one figure.

    Adapted from https://stackoverflow.com/a/47664533 (see the URL at the
    top of the file).
    """

    def __init__(self, seaborngrid, fig, subplot_spec):
        """Move ``seaborngrid`` into ``subplot_spec`` of ``fig``.

        Args:
            seaborngrid: a seaborn ``FacetGrid``, ``PairGrid`` or
                ``JointGrid`` whose axes should be moved.
            fig: the destination matplotlib figure.
            subplot_spec: the slot of ``fig`` (e.g. ``gridspec[i]``) that
                will hold the grid's axes.

        Raises:
            TypeError: if ``seaborngrid`` is not one of the supported grid
                types.  Raised before anything is modified, so the source
                figure is left open.
        """
        self.fig = fig
        self.sg = seaborngrid
        self.subplot = subplot_spec

        axisgrid = sns.axisgrid
        if isinstance(self.sg, (axisgrid.FacetGrid, axisgrid.PairGrid)):
            self._movegrid()
        elif isinstance(self.sg, axisgrid.JointGrid):
            self._movejointgrid()
        else:
            # Without this guard an unsupported object would have its
            # figure closed by _finalize() although no axes were moved.
            raise TypeError(
                "seaborngrid must be a seaborn FacetGrid, PairGrid or "
                "JointGrid, got {}".format(type(self.sg).__name__)
            )
        self._finalize()

    def _movegrid(self):
        """Move the axes of a PairGrid or FacetGrid.

        Builds a nested grid with the same number of rows and columns as
        ``self.sg.axes`` and moves each axes to the matching cell.
        """
        # NOTE(review): indexes shape[0] and shape[1], so ``axes`` must be
        # 2-D; a FacetGrid built with ``col_wrap`` may expose a flat array
        # instead -- confirm before using it with such grids.
        self._resize()
        n_rows = self.sg.axes.shape[0]
        n_cols = self.sg.axes.shape[1]
        self.subgrid = gridspec.GridSpecFromSubplotSpec(
            n_rows, n_cols, subplot_spec=self.subplot)
        for row in range(n_rows):
            for col in range(n_cols):
                self._moveaxes(self.sg.axes[row, col], self.subgrid[row, col])

    def _movejointgrid(self):
        """Move the three axes of a JointGrid.

        The nested grid is ``(r + 1) x (r + 1)`` cells, where ``r`` is the
        rounded ratio of the joint axes height to the x-marginal height.
        The top row holds the x marginal, the last column holds the
        y marginal and the remaining cells hold the joint axes.
        """
        # Measure the ratio before _resize() changes the source figure.
        joint_height = self.sg.ax_joint.get_position().height
        marg_height = self.sg.ax_marg_x.get_position().height
        ratio = int(np.round(joint_height / marg_height))
        self._resize()
        self.subgrid = gridspec.GridSpecFromSubplotSpec(
            ratio + 1, ratio + 1, subplot_spec=self.subplot)
        self._moveaxes(self.sg.ax_joint, self.subgrid[1:, :-1])
        self._moveaxes(self.sg.ax_marg_x, self.subgrid[0, :-1])
        self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1])

    def _moveaxes(self, ax, gs):
        """Detach ``ax`` from its figure and attach it to ``self.fig`` at ``gs``.

        Args:
            ax: the axes to move.
            gs: the subplot spec inside ``self.fig`` that ``ax`` should
                occupy.
        """
        # https://stackoverflow.com/a/46906599/4124317
        ax.remove()
        # The figure reference is set before add_axes() is called.
        ax.figure = self.fig
        self.fig.add_axes(ax)
        ax.set_position(gs.get_position(self.fig))
        # Public setter; no need to write the private _subplotspec field.
        ax.set_subplotspec(gs)

    def _finalize(self):
        """Close the grid's own figure, hook up resizing and redraw."""
        plt.close(self.sg.fig)
        self.fig.canvas.mpl_connect("resize_event", self._resize)
        self.fig.canvas.draw()

    def _resize(self, evt=None):
        """Set the grid's figure size to the destination figure's size.

        Args:
            evt: unused; accepted so the method can be registered as a
                ``resize_event`` callback.
        """
        self.sg.fig.set_size_inches(self.fig.get_size_inches())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@dkapitan Hi Daniel, thank you so much for sharing this package with me. I just realized that the marginal distributions were in fact plotted, but their color and line width are too light/thin to be visible. I had checked the PDF files in the browser and could not see the distribution plots; only after downloading the file and zooming in on the PDF was I able to figure this out. Sincerely, thank you again for your quick response!