Created
July 11, 2022 18:35
-
-
Save matthewcarbone/f5201b1c44963ff9453b9cc1d5f768ac to your computer and use it in GitHub Desktop.
Helper for making nice plots in matplotlib
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib as mpl | |
import matplotlib.pyplot as plt | |
from mpl_toolkits import axes_grid1 | |
class MPLAdjutant:
    """Helper for configuring global matplotlib style and sizing figures/axes.

    Figure sizes are derived from ``self.width`` (inches, as passed to
    ``Figure.set_size_inches``) and ``self.height``, which is ``self.width``
    divided by 1.618 (approximately the golden ratio).
    """

    def __init__(self):
        # DPI written to rcParams['figure.dpi'] by ``set_defaults``.
        self.default_DPI = 250
        # Base figure width in inches. NOTE(review): presumably a journal
        # single-column width (see ``set_size_one_column``) -- confirm.
        self.width = 3.487
        # Height chosen so that width / height is ~1.618 (golden ratio).
        self.height = self.width / 1.618

    def set_default_font(self, labelsize=12, usetex=True):
        """Configure STIX fonts and tick/axis label sizes globally.

        Parameters
        ----------
        labelsize : float
            Font size applied to x tick labels, y tick labels and axis labels.
        usetex : bool
            Value written to ``rcParams['text.usetex']``. When True,
            matplotlib renders text through an external LaTeX installation,
            which must be available on the system.
        """
        mpl.rcParams['mathtext.fontset'] = 'stix'
        mpl.rcParams['font.family'] = 'STIXGeneral'
        mpl.rcParams['text.usetex'] = usetex
        for group in ('xtick', 'ytick', 'axes'):
            plt.rc(group, labelsize=labelsize)

    def set_defaults(self, labelsize=12, usetex=True):
        """Apply the default figure DPI and font settings globally.

        Parameters
        ----------
        labelsize : float
            Forwarded to ``set_default_font``.
        usetex : bool
            Forwarded to ``set_default_font``.
        """
        mpl.rcParams['figure.dpi'] = self.default_DPI
        self.set_default_font(labelsize=labelsize, usetex=usetex)

    def set_size_one_column(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` to ``(width * xwidth, height * xheight)`` inches."""
        fig.set_size_inches(self.width * xwidth, self.height * xheight)

    def set_size_square(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` to a ``height`` x ``height`` square, scaled per axis."""
        fig.set_size_inches(self.height * xwidth, self.height * xheight)

    def set_size_inset(self, fig, xwidth=1.0, xheight=1.0):
        """Size ``fig`` to a square of side ``width / 2``, scaled per axis."""
        fig.set_size_inches(
            self.width * xwidth / 2.0, self.width * xheight / 2.0
        )

    @staticmethod
    def _centered_ticks(vmin, vmax, n):
        """Return the centers of ``n`` equal-width bins spanning [vmin, vmax]."""
        if n <= 0:
            # Nothing to place; avoid dividing by zero below.
            return []
        step = (vmax - vmin) / n
        return [vmin + step * (ii + 0.5) for ii in range(n)]

    @staticmethod
    def add_colorbar(
        im, aspect=10, pad_fraction=0.5, integral_ticks=None, **kwargs
    ):
        """Add a vertical colorbar to the right of an image plot.

        The colorbar axes is sized relative to the height of the image axes
        so that it matches the plot. Based on
        https://stackoverflow.com/questions/18195758/set-matplotlib-colorbar-size-to-match-graph

        Parameters
        ----------
        im
            A mappable accepted by ``Figure.colorbar`` (e.g. the result of
            ``imshow``). Its ``axes`` attribute locates the host axes.
        aspect : float
            Ratio of the image-axes height to the colorbar width.
        pad_fraction : float
            Gap between image and colorbar, as a fraction of the colorbar
            width.
        integral_ticks : sequence, optional
            If given, one tick label per entry, placed at the center of
            ``len(integral_ticks)`` equal-width bins spanning the colorbar
            range.
        **kwargs
            Forwarded to ``Figure.colorbar``.

        Returns
        -------
        matplotlib.colorbar.Colorbar
            The created colorbar.
        """
        host_ax = im.axes
        divider = axes_grid1.make_axes_locatable(host_ax)
        width = axes_grid1.axes_size.AxesY(host_ax, aspect=1. / aspect)
        pad = axes_grid1.axes_size.Fraction(pad_fraction, width)
        # ``append_axes`` may change pyplot's current axes; restore whatever
        # was current beforehand so subsequent ``plt.*`` calls are unaffected.
        current_ax = plt.gca()
        cax = divider.append_axes("right", size=width, pad=pad)
        plt.sca(current_ax)
        cbar = host_ax.figure.colorbar(im, cax=cax, **kwargs)
        if integral_ticks is not None:
            cbar.set_ticks(MPLAdjutant._centered_ticks(
                cbar.vmin, cbar.vmax, len(integral_ticks)
            ))
            cbar.set_ticklabels(integral_ticks)
        return cbar

    @staticmethod
    def _set_lims(ax, low, high, which, threshold):
        """Set the x or y limits of ``ax`` with a symmetric margin.

        Parameters
        ----------
        ax
            Axes object providing ``set_xlim`` / ``set_ylim``.
        low, high : float
            Data range to display; must satisfy ``high > low``.
        which : {'x', 'y'}
            Which axis to set.
        threshold : float
            Margin added on each side, as a fraction of ``high - low``
            (e.g. 0.075 pads each side by 7.5% of the range).

        Raises
        ------
        ValueError
            If ``high`` is not greater than ``low`` (including NaN inputs) or
            ``which`` is not ``'x'`` or ``'y'``.
        """
        # Explicit checks instead of ``assert``, which is stripped under -O.
        if not high > low:
            raise ValueError(
                f"high ({high!r}) must be greater than low ({low!r})"
            )
        if which not in ('x', 'y'):
            raise ValueError(f"which must be 'x' or 'y', got {which!r}")
        extend = threshold * (high - low)
        setter = ax.set_xlim if which == 'x' else ax.set_ylim
        setter(low - extend, high + extend)

    def set_xlim(self, ax, low, high, threshold=0.075):
        """Set x limits to [low, high], padded by ``threshold`` of the range."""
        self._set_lims(ax, low, high, which='x', threshold=threshold)

    def set_ylim(self, ax, low, high, threshold=0.075):
        """Set y limits to [low, high], padded by ``threshold`` of the range."""
        self._set_lims(ax, low, high, which='y', threshold=threshold)

    def set_xtick_spacing(self, ax, spacing):
        """Place major x ticks at integer multiples of ``spacing``."""
        ax.xaxis.set_major_locator(mpl.ticker.MultipleLocator(spacing))

    def set_ytick_spacing(self, ax, spacing):
        """Place major y ticks at integer multiples of ``spacing``."""
        ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(spacing))

    @staticmethod
    def set_grids(
        ax, minorticks=True, grid=False, bottom=True, left=True, right=True,
        top=True
    ):
        """Point ticks inward on the chosen sides and optionally draw grids.

        Parameters
        ----------
        ax
            Axes to style.
        minorticks : bool
            If True, turn on minor ticks.
        grid : bool
            If True, draw dotted minor grid lines (alpha 0.2) and major grid
            lines (alpha 0.5).
        bottom, left, right, top : bool
            Whether major and minor ticks are drawn on each side.
        """
        if minorticks:
            ax.minorticks_on()
        ax.tick_params(
            which='both', direction='in', bottom=bottom, left=left,
            top=top, right=right
        )
        if grid:
            ax.grid(which='minor', alpha=0.2, linestyle=':')
            ax.grid(which='major', alpha=0.5)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment