Last active: February 20, 2018 19:58
-
-
Save Zsailer/70d47bedb0529be762f0 to your computer and use it in GitHub Desktop.
Make matplotlib plots pretty and standard
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib
import matplotlib.container
# Softer replacements for matplotlib's single-letter color codes.
_SOFT_COLORS = {
    'b': '#0066CC',
    'r': '#CC0000',
    'm': '#660066',
    'g': '#009933',
    'c': '#009999',
    'y': '#FFCC00',
    'k': '#333333',
}


def _ticks_within(ticks, lo, hi):
    """Return the ticks lying inside [lo, hi], tolerating float round-off."""
    # Axis limits may be inverted (lo > hi); normalise before comparing.
    lo, hi = sorted((lo, hi))
    tol = 1e-10 * (hi - lo)
    return [t for t in ticks if lo - tol <= t <= hi + tol]


def _style_axes(ax, extra_limit_frac, spine_width):
    """Pad the view limits, trim the spines to the data range and restyle ticks."""
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    # Read the ticks before padding so they reflect the original view range.
    xticks = _ticks_within(ax.get_xticks(), x_lo, x_hi)
    yticks = _ticks_within(ax.get_yticks(), y_lo, y_hi)

    # Widen the view by `extra_limit_frac` of the span on every side so the
    # data does not sit directly on the spines.
    x_pad = extra_limit_frac * (x_hi - x_lo)
    y_pad = extra_limit_frac * (y_hi - y_lo)
    ax.set_xlim(x_lo - x_pad, x_hi + x_pad)
    ax.set_ylim(y_lo - y_pad, y_hi + y_pad)

    # Keep only the left/bottom spines, drawn across the original (unpadded)
    # range and thickened.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['bottom'].set_bounds(x_lo, x_hi)
    ax.spines['left'].set_bounds(y_lo, y_hi)
    ax.spines['bottom'].set_linewidth(spine_width)
    ax.spines['left'].set_linewidth(spine_width)

    # Ticks only on the visible spines, pointing outward, as thick as them.
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.tick_params(direction='out', width=spine_width)

    # Drop ticks beyond either end of the spine: they would dangle past it,
    # and set_xticks/set_yticks would widen the view again to show them.
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)


def _errorbar_containers(ax):
    """Return the ErrorbarContainers registered on `ax` (one per errorbar call)."""
    return [c for c in ax.containers
            if isinstance(c, matplotlib.container.ErrorbarContainer)]


def _style_data(ax, line_width):
    """Soften line colors, thicken lines, and color errorbars to match their data line."""
    for line in ax.lines:
        color = line.get_color()
        # Colors without a softer counterpart (hex strings, 'C0', RGB tuples,
        # arrays) are kept as-is instead of raising KeyError.
        soft = _SOFT_COLORS.get(color, color) if isinstance(color, str) else color
        # Force the line and all marker parts to one color.
        line.set_color(soft)
        line.set_markerfacecolor(soft)
        line.set_markeredgecolor(soft)
        line.set_markerfacecoloralt(soft)
        line.set_linewidth(line_width)

    # Each ErrorbarContainer holds (data_line, caplines, barlinecols). The cap
    # lines are in ax.lines and were recolored above; the bar collections
    # are recolored here from their own data line.
    for container in _errorbar_containers(ax):
        data_line, _caplines, barlinecols = container.lines
        if data_line is None:  # fmt='none': no data line to copy a color from
            continue
        for barcol in barlinecols:
            barcol.set_color(data_line.get_color())


def _style_legend(ax, legend_loc):
    """Rebuild an existing legend without a frame and with small text."""
    if ax.get_legend() is None:
        return
    if not _errorbar_containers(ax):
        ax.legend(frameon=False, loc=legend_loc, fontsize="small")
        return
    handles, labels = ax.get_legend_handles_labels()
    # Errorbar handles are ErrorbarContainers (tuple subclasses). Show only
    # their data line so each entry is a single marker without the bars.
    handles = [h[0] if isinstance(h, tuple) and h[0] is not None else h
               for h in handles]
    ax.legend(handles, labels, numpoints=1, frameon=False,
              loc=legend_loc, fontsize="small")


def prettify(ax, legend_loc=4, *, extra_limit_frac=0.05, spine_width=1.35,
             line_width=1.5):
    """A simple wrapper to make matplotlib figures prettier.

    Restyles `ax` in place. The axis-level changes are applied only on the
    first call for a given Axes, which is tagged with a ``prettify``
    attribute: padding the limits, trimming spines and ticks, outward ticks.
    Line and errorbar colors and the legend are restyled on every call; that
    styling is idempotent.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to restyle.
    legend_loc : int or str, optional
        ``loc`` passed to ``ax.legend`` when an existing legend is rebuilt.
        The default 4 is matplotlib's code for "lower right".
    extra_limit_frac : float, optional
        Fraction of each axis span added as padding on both sides (0.05 = 5%).
    spine_width : float, optional
        Line width of the visible spines and of the ticks.
    line_width : float, optional
        Line width applied to every data line.
    """
    # Padding is cumulative, so applying it twice would keep widening the view.
    if not hasattr(ax, "prettify"):
        ax.prettify = True
        _style_axes(ax, extra_limit_frac, spine_width)
    _style_data(ax, line_width)
    _style_legend(ax, legend_loc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.