Skip to content

Instantly share code, notes, and snippets.

@tmramalho
Created October 26, 2018 09:37
Show Gist options
  • Save tmramalho/dc9d76dad4cea21114122431a034f80c to your computer and use it in GitHub Desktop.
Show a 4D NumPy array (a batch of images) as a single 2D image grid
def image_grid(data, n_cols=None, zoom=1):
    """Tile a batch of images into a single image laid out as a grid.

    Images are placed row-major: image ``k`` goes to grid row ``k // n_cols``
    and grid column ``k % n_cols``. Unused cells in the last grid row are
    filled with zeros.

    Args:
        data: Array of shape ``(N, H, W, C)``, or ``(N, H, W)`` for images
            without a channel axis. Height and width need not be equal.
        n_cols: Number of images per grid row. Defaults to
            ``ceil(sqrt(N))`` (at least 1), giving a roughly square grid.
        zoom: Pixel-replication magnification. Only powers of two are
            applied: the effective factor is ``2 ** floor(log2(zoom))``
            (so ``zoom=3`` magnifies by 2), and values below 2 leave the
            grid unmagnified.

    Returns:
        Array with dtype ``data.dtype`` and shape
        ``(H * n_rows * f, W * n_cols * f, C)``, where ``f`` is the effective
        zoom factor and ``n_rows = ceil(N / n_cols)``. The array is then passed
        through ``np.squeeze``. A single-channel result therefore comes back
        2D, and any other length-1 axis is dropped as well.

    Raises:
        ValueError: If ``data`` is not 3D or 4D, or if ``n_cols < 1``.
    """
    data = np.asarray(data)
    if data.ndim == 3:
        # Batch without a channel axis; the final squeeze removes it again.
        data = data[..., np.newaxis]
    if data.ndim != 4:
        raise ValueError(f"expected a 3D or 4D array, got shape {data.shape}")
    n_images, height, width, channels = data.shape

    if n_cols is None:
        # max(1, ...) keeps an empty batch from producing a 0-column grid.
        n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")
    n_rows = -(-n_images // n_cols)  # integer ceiling division

    # Pad the batch with blank (zero) images so it fills the grid exactly.
    n_blank = n_rows * n_cols - n_images
    padded = np.pad(data, [(0, n_blank), (0, 0), (0, 0), (0, 0)])

    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C):
    # each grid row places its images' pixel rows side by side. This works
    # for any H and W, unlike a single flatten that assumes H == W.
    grid = (
        padded.reshape(n_rows, n_cols, height, width, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(n_rows * height, n_cols * width, channels)
    )

    # Keep the established semantics: zoom is rounded down to a power of two.
    n_doublings = int(np.log2(zoom)) if zoom >= 1 else 0
    if n_doublings > 0:
        factor = 2 ** n_doublings
        # Each pixel becomes a factor x factor block (nearest-neighbour upscale).
        grid = grid.repeat(factor, axis=0).repeat(factor, axis=1)

    return np.squeeze(grid)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment