Last active: July 3, 2016 14:06
-
-
Save raggleton/329b8cdaf48261a22d1a105bd895c661 to your computer and use it in GitHub Desktop.
Functions and decorators that make it easier to save plots from IPython/Jupyter notebooks.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools
import os
from contextlib import contextmanager
from copy import deepcopy
# Master switch for the whole notebook: when False, decorated plot functions
# still draw, but nothing is written to disk.
SAVE_PLOTS = True
def save_plot(filename, **savefig_kwargs):
    """Save the current matplotlib figure, creating parent dirs if needed.

    Args:
        filename: Destination path. It is converted to an absolute path
            (relative to the current working directory) before saving.
        **savefig_kwargs: Extra keyword arguments forwarded unchanged to
            ``plt.savefig`` (for example ``dpi``). Optional; omitting them
            keeps the original behaviour.

    Raises:
        OSError: if the parent directory cannot be created for a reason
            other than it already existing.
    """
    filename = os.path.abspath(filename)
    plot_dir = os.path.dirname(filename)
    # Try to create the directory and tolerate "already exists" afterwards,
    # instead of checking first: a separate isdir() check can race with
    # another process creating the same directory. (exist_ok is not used so
    # the code stays Python 2 compatible, matching the rest of the file.)
    try:
        os.makedirs(plot_dir)
    except OSError:
        if not os.path.isdir(plot_dir):
            raise
    # NOTE(review): `plt` is never imported in this file; this assumes
    # matplotlib.pyplot is already available as `plt` in the notebook.
    plt.savefig(filename, **savefig_kwargs)
def get_backend():
    """Return the inline backend's figure formats as a list.

    Runs the IPython ``%config InlineBackend.figure_formats`` line magic and
    converts its result to a list.

    NOTE(review): this assumes ``%config`` returns the current value when
    called without an assignment. Some IPython versions may print it and
    return ``None`` instead, which would make ``list(...)`` raise TypeError;
    confirm against the IPython version in use.
    """
    # run_line_magic is the supported spelling of the deprecated
    # InteractiveShell.magic("config ...") call.
    formats = get_ipython().run_line_magic(
        u"config", u"InlineBackend.figure_formats")
    return list(formats)
def set_backend(fmt):
    """Set the inline backend's figure formats to the single format ``fmt``.

    Prints a short notice, then runs the IPython ``%config`` line magic.

    Args:
        fmt: Format name such as ``'png'``. It is embedded with ``%r`` so the
            generated config line is always a correctly quoted one-element
            tuple.
    """
    # A single parenthesised argument prints identically on Python 2 and 3;
    # a bare ``print 'a', b`` statement is a SyntaxError on Python 3.
    print('Setting backend %s' % (fmt,))
    get_ipython().run_line_magic(
        u"config", u"InlineBackend.figure_formats = (%r,)" % (fmt,))
@contextmanager
def backend(fmt):
    """Temporarily switch the inline figure format to ``fmt``.

    The first format reported by ``get_backend()`` is remembered on entry.
    It is restored on exit, including when the ``with`` body raises, so a
    failing plot call does not leave the notebook stuck on ``fmt``.

    Args:
        fmt: Format name passed to ``set_backend`` (e.g. ``'png'``).
    """
    old_fmt = get_backend()[0]
    set_backend(fmt)
    try:
        yield
    finally:
        # Without try/finally, an exception raised inside the with-block
        # would propagate out of ``yield`` and skip this restore entirely.
        set_backend(old_fmt)
def save_fig(func):
    """Decorator that makes a plotting function save its figure to a file.

    The decorated function accepts one extra, required keyword argument,
    ``filename``. That argument is removed before ``func`` is called.

    ``func`` runs inside ``backend(fmt)``. Here ``fmt`` is the file
    extension, or ``'png'`` when ``SAVE_PLOTS`` is False. When
    ``SAVE_PLOTS`` is True, the figure is then written to ``filename`` via
    ``save_plot``. The wrapper returns whatever ``func`` returns.

    Raises:
        IOError: if ``filename`` is missing or empty, or if it has no file
            extension.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # kwargs is a fresh dict for this call, so popping is safe. It also
        # passes the remaining arguments through by reference; a deep copy
        # would hand func copies of objects such as Axes, so it would draw
        # on the copies rather than on the caller's objects.
        filename = kwargs.pop('filename', '')
        if not filename:
            raise IOError('No filename')
        # Check the real extension instead of "any '.' in the path":
        # 'plots.v2/thing' contains a dot but has no extension.
        ext = os.path.splitext(filename)[1].lstrip('.')
        if not ext:
            raise IOError('No extension in filename: %s' % filename)
        fmt = ext if SAVE_PLOTS else 'png'
        with backend(fmt):
            result = func(*args, **kwargs)
            if SAVE_PLOTS:
                save_plot(filename)
        return result
    return wrapper
# EXAMPLE: decorate an ordinary plotting function with @save_fig, then pass
# filename= at the call site; the extension selects the output format.
@save_fig
def plot_thing(title):
    """Plot one straight line segment and title the figure with ``title``."""
    xs = [1, 2]
    ys = [3, 4]
    plt.plot(xs, ys)
    plt.suptitle(title)


plot_thing('Silly plot', filename='thing.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment