
@salotz
Created March 1, 2019 20:22
Move a matplotlib Axes from one figure to another.
import matplotlib.pyplot as plt


def move_axes(ax, fig, subplot_spec=111):
    """Move an Axes object from its figure to another (pyplot managed)
    Figure, placing it in the specified subplot.

    subplot_spec may be a three-digit integer (e.g. 111), a SubplotSpec,
    or an (nrows, ncols, index) tuple.
    """
    # get a reference to the old figure context so we can release it
    old_fig = ax.figure
    # remove the Axes from its original Figure context
    ax.remove()
    # set the pointer from the Axes to the new figure
    ax.figure = fig
    # add the Axes to the new figure's registry of axes. (Appending to
    # fig.axes is not enough: fig.axes returns a new list on every access,
    # so fig.axes.append(ax) has no effect.)
    fig.add_axes(ax)
    # then to actually show the Axes in the new figure we have to make
    # a subplot with the positions etc for the Axes to go, so make a
    # subplot which will have a dummy Axes
    if isinstance(subplot_spec, tuple):
        dummy_ax = fig.add_subplot(*subplot_spec)
    else:
        dummy_ax = fig.add_subplot(subplot_spec)
    # then copy the relevant data from the dummy to the ax
    ax.set_position(dummy_ax.get_position())
    # then remove the dummy
    dummy_ax.remove()
    # close the figure the original Axes was bound to
    plt.close(old_fig)
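Why only `fig.add_axes(ax)` matters for registering the moved Axes can be checked directly: `Figure.axes` builds a new list each time it is accessed (matplotlib's docs say not to modify that list and to use `add_axes`, `add_subplot` or `delaxes` instead), so appending to it changes nothing. A minimal sketch; the Agg backend is chosen only so it runs headless:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the snippet runs anywhere
import matplotlib.pyplot as plt

fig = plt.figure()
fig.add_subplot(111)

# Figure.axes builds a new list on every access...
fresh_list_each_time = fig.axes is not fig.axes

# ...so appending an Axes to that list does not register it with the figure
other_ax = plt.figure().add_subplot(111)
fig.axes.append(other_ax)
n_registered = len(fig.axes)

print(fresh_list_each_time, n_registered)  # True 1
```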
@digitalsignalperson

Re: Line 27, I just created a single new figure to move all the axes onto, so that all subplots get the same weird offsets, if any.

e.g.

import matplotlib.pyplot as plt

all_axes = []
for thing in things:
    axes = make_complicated_plot()
    all_axes.append(axes)

fig = plt.figure()
nrows = len(all_axes)
old_figs = []
for i_row, axes in enumerate(all_axes):
    ncols = len(axes[0])
    for i_col, ax in enumerate(axes[0]):
        old_fig = move_axes(ax, fig, (nrows, ncols, 1 + i_col + i_row * ncols))
        if i_col == 0:
            old_figs.append(old_fig)

for old_fig in old_figs:
    plt.close(old_fig)

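This uses a couple of small tweaks from my fork of this gist (e.g. move_axes returns the old figure instead of closing it, so the old figures can all be closed at the end).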

Matplotlib subfigures are also cool for stuff like this, but I found them still a little janky.
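For comparison, a rough sketch of the subfigures route: instead of moving Axes between figures, each row is drawn directly into its own `SubFigure` (`Figure.subfigures` needs matplotlib 3.4 or newer). The category names and pie data are hypothetical placeholders, not the commenter's code:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, only so the sketch runs anywhere
import matplotlib.pyplot as plt

categories = ["north", "south", "east"]  # hypothetical categories
ncols = 4

fig = plt.figure(figsize=(12, 8))
# one row-shaped SubFigure per category
subfigs = fig.subfigures(len(categories), 1)
for subfig, cat in zip(subfigs, categories):
    # each SubFigure gets its own row of Axes, drawn in place; nothing is moved
    axes = subfig.subplots(1, ncols)
    for ax in axes:
        ax.pie([1, 2, 3])  # placeholder data
    subfig.supylabel(cat)  # row label on the left of each SubFigure
```

Because nothing is re-parented, this avoids the offset quirks of moving Axes, at the cost of having to pass the target Axes into the plotting code up front.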

@engeir

engeir commented Jun 20, 2024

@digitalsignalperson What does make_complicated_plot() return? It seems like it should be a matplotlib.axes.Axes, but then you compute len(axes[0]), so the return of make_complicated_plot() must then be subscriptable.

@digitalsignalperson

It was the return value of pandas.DataFrame.plot(), which (with subplots=True and a layout) is a 2D numpy array of Axes.
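A small check of that return type, assuming a toy DataFrame (the column names and values are made up): with `subplots=True` and `layout=(2, 2)`, the result is indexable by row, so `axes[0]` is the first row of Axes that the loops in this thread iterate over.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the snippet runs anywhere
import matplotlib.pyplot as plt
import pandas as pd

# hypothetical toy data: one pie per column
df = pd.DataFrame(
    {"a": [1, 2, 3], "b": [3, 2, 1], "c": [1, 1, 1], "d": [2, 2, 1]},
    index=["x", "y", "z"],
)

axes = df.plot(kind="pie", subplots=True, layout=(2, 2), legend=False)

# indexing the array by row gives one row of Axes
first_row = axes[0]
print(type(axes).__name__, axes.shape, len(first_row))
plt.close("all")
```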

more details:

import matplotlib.pyplot as plt

figsize = (23, 11.5)
nrows = 4
ncols = 7

all_axes = []
for category in categories:
    # group/filter a dataframe based on the category. assume it is df here
    axes = df.plot(kind='pie', subplots=True, layout=(nrows, ncols), figsize=figsize, legend=False, ylabel='', title='example')
    all_axes.append((category, axes))

fig = plt.figure(figsize=figsize)
# subplots_adjust(...) as needed

old_figs = []
for i_row, (cat, axes) in enumerate(all_axes):
    for i_col, ax in enumerate(axes[0]):
        old_fig = move_axes(ax, fig, (nrows, ncols, 1 + i_col + i_row * ncols))
        if i_col == 0:
            old_figs.append(old_fig)
            # label each row with its category, left of the row's first Axes
            ax.text(-0.25, 0.5, cat, transform=ax.transAxes, rotation=90, ha='center', va='center')
for old_fig in old_figs:
    plt.close(old_fig)

Sorry it's not a complete example, but hopefully it works if you play with it.
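The position argument `1 + i_col + i_row * ncols` converts a zero-based (row, column) pair into the 1-based, row-major index that `add_subplot(nrows, ncols, index)` expects. A quick sketch confirming the arithmetic; `subplot_index` is a hypothetical helper name, not part of the gist:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the snippet runs anywhere
import matplotlib.pyplot as plt


def subplot_index(i_row, i_col, ncols):
    """Hypothetical helper: 1-based, row-major index for add_subplot."""
    return 1 + i_col + i_row * ncols


nrows, ncols = 4, 7
fig = plt.figure()
ax = fig.add_subplot(nrows, ncols, subplot_index(2, 3, ncols))  # index 18
spec = ax.get_subplotspec()
print(spec.rowspan.start, spec.colspan.start)  # 2 3
```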


@engeir

engeir commented Jun 20, 2024

Right, I see. Thanks!
