Last active
January 18, 2018 03:16
-
-
Save ZGainsforth/f16e032f26994d02e480 to your computer and use it in GitHub Desktop.
One function for configurable plots in matplotlib. This should turn about 90% of plots into one-liners. It allows scaling of line and text size, plotting of multiple lines, and returns the figure and axes in case the user wants to tweak the plot further.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Created 2014, Zack Gainsforth | |
| import matplotlib | |
| matplotlib.use('Qt4Agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from scipy.interpolate import interp1d | |
| from itertools import chain | |
def QuickPlot(x, y, linestyle='', xlabel=None, ylabel=None, title=None, legendstrs=None, legendloc=None, savefile=None, boldlevel=1, figax=None, plottype=None, xlim=None, ylim=None, **kwargs):
    """
    Plot one or several lines on a matplotlib axis with a single call.

    :param x: 1D or 2D array-like. If 1D, there is one plot. If 2D, every row is the x-axis for another line on the plot.
    :param y: Same shape as x; row i of y is plotted against row i of x.
    :param linestyle: matplotlib format string (e.g. 'r--', 'o') used for 1D line plots. Not applied to scatter plots or to 2D input.
    :param xlabel: String (can include TeX) for the label of the x-axis.
    :param ylabel: String (can include TeX) for the label of the y-axis.
    :param title: String (can include TeX) for the title.
    :param legendstrs: List of strings, one per plotted line (one per row of x and y when they are 2D).
    :param legendloc: Passed to ax.legend as ``loc``; None uses matplotlib's default location.
    :param savefile: A filename to save the figure to, if desired.
    :param boldlevel: Sets the linewidths fatter and texts bigger. 1 is normal, 4 is usually good for publication.
    :param figax: None makes a new figure and axis. Or plots on an existing one if a tuple is given: (fig, ax)
    :param plottype: None for a linear line plot, or one of 'scatter', 'semilogy', 'semilogx', 'loglog'.
        Any other value falls back to 'loglog'.
    :param xlim: Display range of the x-axis, e.g. (min, max), or None to autoscale.
    :param ylim: Display range of the y-axis, e.g. (min, max), or None to autoscale.
    :param kwargs: Extra keyword arguments forwarded to the underlying matplotlib plotting call.
    :return: (fig, ax) The figure and axis so the user can do further tweaking.
    :raises TypeError: If x is not 1D or 2D.

    Note: this also changes the global matplotlib rc settings for tick label size and
    major tick padding, so later plots in the same session are affected too.
    """
    x = np.asanyarray(x)
    y = np.asanyarray(y)
    # Validate before touching any global state or creating a figure.
    if x.ndim not in (1, 2):
        raise TypeError('x is not a 1D or 2D numpy array.')

    # Set the bold level.
    FontSizeBasis = (boldlevel + 2) * 4  # Fonts get bigger as boldlevel increases.
    TickMajorBasis = boldlevel * 4  # As fonts get bigger, they need a larger padding from the axis.

    # Increase the size of the tick label fonts and their padding (global rc settings,
    # kept for callers that rely on them carrying over to later plots).
    matplotlib.rc('xtick', labelsize=FontSizeBasis)
    matplotlib.rc('ytick', labelsize=FontSizeBasis)
    matplotlib.rc('xtick.major', pad=TickMajorBasis)
    matplotlib.rc('ytick.major', pad=TickMajorBasis)

    # Create/reuse a figure to plot on.
    if figax is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figax
    # rc values are read when ticks are created, so an axis passed in via figax may already
    # carry the old sizes; apply them to this axis explicitly as well.
    ax.tick_params(axis='both', which='both', labelsize=FontSizeBasis)
    ax.tick_params(axis='both', which='major', pad=TickMajorBasis)

    # Line-style plotting methods; anything unrecognised (other than 'scatter') means loglog.
    LinePlotters = {None: ax.plot, 'semilogy': ax.semilogy, 'semilogx': ax.semilogx}

    if x.ndim == 1:
        # Just one line on this plot.
        if plottype == 'scatter':
            ax.scatter(x, y, s=TickMajorBasis**2, **kwargs)
        else:
            Plotter = LinePlotters.get(plottype, ax.loglog)
            Plotter(x, y, linestyle, markersize=TickMajorBasis, linewidth=boldlevel, **kwargs)
    else:
        # Multiple lines on this plot, one per row. linestyle is not applied to 2D input.
        for i in range(x.shape[0]):
            if plottype == 'scatter':
                # Scatter only row i, so each row becomes its own collection (and legend entry).
                ax.scatter(x[i, :], y[i, :], s=TickMajorBasis**2, **kwargs)
            else:
                Plotter = LinePlotters.get(plottype, ax.loglog)
                Plotter(x[i, :], y[i, :], linewidth=boldlevel, **kwargs)

    # Write the labels and title on *this* axis; pyplot's plt.xlabel etc. act on the
    # current axis, which is not ax when figax refers to another figure/subplot.
    if xlabel is not None:
        ax.set_xlabel(xlabel, fontsize=FontSizeBasis)
    if ylabel is not None:
        ax.set_ylabel(ylabel, fontsize=FontSizeBasis)
    if title is not None:
        ax.set_title(title, fontsize=FontSizeBasis)

    # Show a legend if requested.
    if legendstrs is not None:
        ax.legend(legendstrs, fontsize=FontSizeBasis, loc=legendloc)

    # Set the x and y display ranges.
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)

    # Resize the figure appropriately. tight_layout is best-effort: report and carry on,
    # but do not swallow KeyboardInterrupt/SystemExit as a bare except would.
    try:
        fig.canvas.draw_idle()
        fig.tight_layout()
    except Exception:
        print('plt.tight_layout() complained ... again ...')

    # Save it to a file if a name is given.
    if savefile is not None:
        fig.savefig(savefile)
    return fig, ax
| # About 2/3 of my plots are just two-column data, so make those easy to plot. | |
def SpecPlot(S, DivBkg=None, SubBkg=None, **kwargs):
    """
    Plot a two-column spectrum, optionally corrected by background spectra.

    :param S: 2D array-like; column 0 holds the x-axis values (named E here) and column 1 the intensity (I).
    :param DivBkg: Optional two-column array-like (x, value). It is interpolated (scipy interp1d, linear
        by default) onto S's x values, and the intensity is divided by it. Outside DivBkg's x range the
        interpolated value is 0, so those points become inf/nan.
    :param SubBkg: Optional two-column array-like (x, value). It is interpolated onto S's x values and
        subtracted from the intensity (after any division). Outside its x range, 0 is subtracted.
    :param kwargs: Forwarded to QuickPlot.
    :return: (fig, ax, NewSpectrum) where NewSpectrum is an (N, 2) array of x values and processed intensity.
    """
    S = np.asanyarray(S)
    # Make E (x-axis) and I (y-axis). I is a copy, so the in-place ops below never modify the caller's S.
    E = S[:, 0]
    I = np.copy(S[:, 1])
    # In-place /= and -= with a float background cannot store into an integer array (numpy refuses
    # the float -> int cast), so promote non-float intensities when a background will be applied.
    if (DivBkg is not None or SubBkg is not None) and not np.issubdtype(I.dtype, np.inexact):
        I = I.astype(float)
    # Divide by a background spectrum.
    if DivBkg is not None:
        DivBkg = np.asanyarray(DivBkg)
        DivNFunc = interp1d(DivBkg[:, 0], DivBkg[:, 1], bounds_error=False, fill_value=0)
        I /= DivNFunc(E)
    # Subtract a background spectrum.
    if SubBkg is not None:
        SubBkg = np.asanyarray(SubBkg)
        SubNFunc = interp1d(SubBkg[:, 0], SubBkg[:, 1], bounds_error=False, fill_value=0)
        I -= SubNFunc(E)
    NewSpectrum = np.vstack((E, I)).T
    fig, ax = QuickPlot(E, I, **kwargs)
    return fig, ax, NewSpectrum
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment