Skip to content

Instantly share code, notes, and snippets.

@nlintz
Created March 11, 2016 19:34
Show Gist options
  • Save nlintz/f5bde5df8eee57472954 to your computer and use it in GitHub Desktop.
Save nlintz/f5bde5df8eee57472954 to your computer and use it in GitHub Desktop.
Utilities for splitting an image into a sequence of non-overlapping patches and reassembling the patches into an image
import numpy as np
from skimage.util import view_as_blocks
def round_down(num, divisor):
    """Round *num* down to a multiple of *divisor*.

    For a positive divisor this is the largest multiple of *divisor*
    that does not exceed *num* (Python's ``%`` takes the sign of the
    divisor, so negative numbers round toward -infinity).
    """
    _, remainder = divmod(num, divisor)
    return num - remainder
def crop_center(img, new_rows, new_cols):
    """Crop the central ``new_rows`` x ``new_cols`` window out of an image.

    img: [rows, cols, ...] array, e.g. [rows, cols, nchannels] or a 2-D
        grayscale image; any trailing axes are kept whole
    new_rows, new_cols: size of the window; each must be between 0 and the
        corresponding image dimension
    returns --
        view of img with shape [new_rows, new_cols, ...]. When the surplus
        along an axis is odd, the extra row/column is removed from the
        bottom/right side.
    raises --
        ValueError if the requested window does not fit inside the image
    """
    rows, cols = img.shape[:2]
    if not (0 <= new_rows <= rows and 0 <= new_cols <= cols):
        raise ValueError(
            "crop window (%d, %d) does not fit in image of shape %s"
            % (new_rows, new_cols, img.shape)
        )
    # Floor division keeps the offsets integral: "/" yields floats on
    # Python 3, and floats are not valid slice indices.
    top = (rows - new_rows) // 2
    left = (cols - new_cols) // 2
    # Ending at start + size guarantees exactly the requested size
    # regardless of the parity of the surplus.
    return img[top:top + new_rows, left:left + new_cols, ...]
def crop_image_to_patch(img, patch_shape):
    """Center-crop an image so it tiles exactly into patch_shape chunks.

    img: [rows, cols, nchannels]
    patch_shape: (rows, cols, nchannels); only the first two entries
        (patch height and width) are used here
    returns --
        center crop of img whose row and column counts are rounded down
        to multiples of patch_shape[0] and patch_shape[1] respectively
    """
    # Only the spatial dimensions matter for the crop; the channel count
    # is not needed here.
    rows, cols = img.shape[:2]
    new_rows = round_down(rows, patch_shape[0])
    new_cols = round_down(cols, patch_shape[1])
    return crop_center(img, new_rows, new_cols)
def image_to_patches(img, patch_shape):
    """
    converts an image into M patch_shaped chunks

    arguments --
    img: [rows, cols, nchannels]; view_as_blocks requires every dimension
        to be divisible by the matching entry of patch_shape (see
        crop_image_to_patch)
    patch_shape: (rows, cols, nchannels); any sequence of ints is accepted
    returns --
    [M, patch_shape[0], patch_shape[1], nchannels]
    Patches are listed column-major: with R = img rows // patch rows,
    patch k comes from patch-row k % R and patch-column k // R. This is
    the order patches_to_image expects.
    """
    # Both view_as_blocks and the tuple concatenation below need a tuple.
    patch_shape = tuple(patch_shape)
    blocks = view_as_blocks(img, patch_shape)
    # blocks has the block-grid axes first, followed by the patch_shape axes.
    # A Fortran-order reshape flattens the grid axes with the row index
    # varying fastest (column-major patch order) while leaving each
    # patch's own pixel/channel layout intact.
    return blocks.reshape((-1,) + patch_shape, order="F")
def patches_to_image(patches, patch_shape, img_shape):
    """
    converts result from image_to_patches back into an image of the
    original image's shape

    patches: [n_patches, rows, cols, nchannels], ordered column-major
        (patch k sits at patch-row k % patches_per_col and patch-column
        k // patches_per_col), the order image_to_patches emits
    patch_shape: (rows, cols, nchannels); only patch_shape[0] is used
    img_shape: (rows, cols, nchannels); img_shape[0] is expected to be a
        multiple of patch_shape[0]
    returns --
        the reassembled image; it has shape img_shape when the patches
        tile the full image
    """
    # Integer division: range() and slice steps need ints, and true
    # division "/" yields a float on Python 3.
    patches_per_col = img_shape[0] // patch_shape[0]
    # Because patches are column-major, every patches_per_col-th patch
    # starting at i lies in patch-row i. Joining them along the column
    # axis rebuilds that horizontal strip of the image.
    strips = [
        np.concatenate(patches[i::patches_per_col], axis=1)
        for i in range(patches_per_col)
    ]
    # Stack the strips top-to-bottom to recover the full image.
    return np.concatenate(strips, axis=0)
if __name__ == "__main__":
    # Example script: build a blocky random image, split it into 2x2
    # patches, then plot the image, the stacked patches and the
    # reassembled image one above the other.
    import matplotlib.pyplot as plt

    def _show(position, picture):
        """Draw *picture* into the subplot at *position*."""
        plt.subplot(position)
        plt.imshow(picture, interpolation="None")

    patch_size = (2, 2, 3)
    # Upsample a random 4x4 RGB image 2x along rows and columns so that
    # every 2x2 patch is a single flat colour.
    base = np.random.rand(4, 4, 3)
    img = np.repeat(np.repeat(base, 2, axis=0), 2, axis=1)
    img = crop_image_to_patch(img, patch_size)
    patches = image_to_patches(img, patch_size)

    _show(311, img)
    _show(312, np.concatenate(patches))
    _show(313, patches_to_image(patches, patch_size, img.shape))
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment