Skip to content

Instantly share code, notes, and snippets.

@Mason-McGough
Last active March 15, 2019 03:33
Show Gist options
  • Save Mason-McGough/5e9d928ac39cc5bde617feb7e6db9480 to your computer and use it in GitHub Desktop.
Save Mason-McGough/5e9d928ac39cc5bde617feb7e6db9480 to your computer and use it in GitHub Desktop.
A collection of functions for plotting using Matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import FormatStrFormatter, LinearLocator
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older Matplotlib

from .numpy_utils import subsample
def scatter_3d(X, sample_size=None, fig=None, subplot=None, title='X', xlabel='X',
               ylabel='Y', zlabel='Z'):
    """
    Create a 3D scatter plot of an N-by-3 array.

    Inputs:
    X - The N-by-3 array to visualize (any array-like accepted by np.asarray,
        e.g. a list of 3-element rows). Columns 0, 1 and 2 are plotted on the
        x, y and z axes respectively.
    sample_size - If not None, X is reduced with subsample(X, sample_size)
        before plotting. Useful if N is very large. (Default: None)
    fig - The Matplotlib figure to draw on. If None, a new figure is created.
        (Default: None)
    subplot - The subplot position given to Figure.add_subplot. A list or
        tuple (e.g. [1, 2, 1]) is unpacked into separate arguments; any other
        value (e.g. 111) is passed through as-is. If None, one set of 3D axes
        filling the whole figure is created instead. (Default: None)
    title - The plot title. (Default: 'X')
    xlabel - The label applied to the x-axis. (Default: 'X')
    ylabel - The label applied to the y-axis. (Default: 'Y')
    zlabel - The label applied to the z-axis. (Default: 'Z')
    Outputs:
    ax - The 3D Axes the points were drawn on.
    Raises:
    ValueError - If X is not two-dimensional, does not have exactly 3 columns,
        or if subplot is given without fig.
    """
    # Accept array-likes (lists, etc.) as well as ndarrays.
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError('Shape of X must be two dimensions. (shape: {})'.format(X.shape))
    if X.shape[1] != 3:
        raise ValueError('Number of columns of X must be equal to 3. (shape: {})'.format(
            X.shape))
    if fig is None and subplot is not None:
        raise ValueError('If subplot is set, then fig must be provided.')

    if fig is None:
        fig = plt.figure()

    if subplot is None:
        # No position requested: let the 3D axes occupy the whole figure.
        ax = fig.add_axes([0, 0, 1, 1], projection='3d')
    elif isinstance(subplot, (list, tuple)):
        # e.g. (1, 2, 1) -> add_subplot(1, 2, 1, ...)
        ax = fig.add_subplot(*subplot, projection='3d')
    else:
        # e.g. 111, passed through unchanged.
        ax = fig.add_subplot(subplot, projection='3d')

    if sample_size is not None:
        X = subsample(X, sample_size)

    ax.scatter(X[:, 0], X[:, 1], X[:, 2])
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    return ax
def grid_imshow(grid, imgs_list, imgs2_list):
    """
    Draw paired images as vertically stacked subplots on the current figure.

    The current figure is split into 2 * grid[0] subplot rows and grid[1]
    subplot columns. For grid cell (r, c), the image at flat index
    grid[1] * r + c of imgs_list goes in subplot row 2 * r, and the image at
    the same index of imgs2_list goes directly beneath it in row 2 * r + 1.
    Inputs:
    grid - A sequence whose first two items are the number of grid rows and
        grid columns.
    imgs_list - The upper image of each pair, in row-major order.
    imgs2_list - The lower image of each pair. Must be the same length as
        imgs_list.
    Outputs:
    None
    """
    row_count = grid[0]
    subplot_rows = 2 * row_count
    for row in range(row_count):
        col_count = grid[1]
        # 1-based subplot number of the first upper image in this grid row;
        # each grid row occupies two subplot rows.
        first_top = 2 * row * col_count + 1
        for col in range(col_count):
            idx = col_count * row + col
            # Fetch both images before drawing anything for this cell.
            top_img = imgs_list[idx]
            bottom_img = imgs2_list[idx]
            top_pos = first_top + col
            plt.subplot(subplot_rows, col_count, top_pos)
            plt.imshow(top_img)
            # The lower partner sits exactly one subplot row further down.
            plt.subplot(subplot_rows, col_count, top_pos + col_count)
            plt.imshow(bottom_img)
def plot_fXY(X, Y, fn, linewidth=0):
    """
    Plot the surface Z = fn(X, Y) over the X-Y plane and show it.

    Example:
        X = np.arange(-5, 5, 0.025)
        Y = np.arange(-5, 5, 0.025)
        fn = lambda X, Y : np.sin(np.sqrt(X**2 + Y**2))
        plot_fXY(X, Y, fn)
    Inputs:
    X - Array of samples along the x-axis.
    Y - Array of samples along the y-axis.
    fn - The function to plot. It is called once as fn(XX, YY), where XX and
        YY are the coordinate grids returned by np.meshgrid(X, Y), so it must
        operate element-wise on arrays (e.g. built from NumPy ufuncs).
    linewidth - The width of the surface's grid lines. (Default: 0)
    Output:
    None
    """
    X, Y = np.meshgrid(X, Y)
    Z = fn(X, Y)

    # Create figure. Figure.gca(projection=...) was removed in Matplotlib 3.6;
    # add_subplot is the supported way to request 3D axes.
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm, linewidth=linewidth,
                           antialiased=False)

    # Customize the z axis: 10 evenly spaced ticks formatted with two decimals.
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

    # Add a color bar which maps values to colors.
    fig.colorbar(surf, shrink=0.5, aspect=5)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment