Last active
March 19, 2019 20:49
-
-
Save wookayin/db1800194c0e44b12316696376ecd01a to your computer and use it in GitHub Desktop.
IPython notebook snippet for plotting multiple images in a grid.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# in a courtesy of Caffe's filter visualization example | |
# http://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb | |
def imshow_grid(data, height=None, width=None, normalize=False, padsize=1, padval=0): | |
''' | |
Take an array of shape (N, H, W) or (N, H, W, C) | |
and visualize each (H, W) image in a grid style (height x width). | |
''' | |
if normalize: | |
data -= data.min() | |
data /= data.max() | |
N = data.shape[0] | |
if height is None: | |
if width is None: | |
height = int(np.ceil(np.sqrt(N))) | |
else: | |
height = int(np.ceil( N / float(width) )) | |
if width is None: | |
width = int(np.ceil( N / float(height) )) | |
assert height * width >= N | |
# append padding | |
padding = ((0, (width*height) - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3) | |
data = np.pad(data, padding, mode='constant', constant_values=(padval, padval)) | |
# tile the filters into an image | |
data = data.reshape((height, width) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) | |
data = data.reshape((height * data.shape[1], width * data.shape[3]) + data.shape[4:]) | |
plt.imshow(data) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment