Skip to content

Instantly share code, notes, and snippets.

@karolzak
Last active January 16, 2019 22:27
Show Gist options
  • Save karolzak/34725ba2c1f50db836a9b1c2886dc7a1 to your computer and use it in GitHub Desktop.
Save karolzak/34725ba2c1f50db836a9b1c2886dc7a1 to your computer and use it in GitHub Desktop.
Returns crops (patches) from a single image (NumPy array) or from an array of images.
def get_patches(img_arr, size=256, stride=256):
    '''
    Cut square patches out of one image or a batch of images using a
    sliding window.

    The window starts at index (0, 0) and advances ``stride`` elements at
    a time along each of the two windowed axes. Only windows that fit
    completely inside the image are kept, so any remainder at the far
    edges is dropped. If stride < size, neighbouring patches overlap.

    Args:
        img_arr: numpy array. With ndim == 3 it is treated as a single
            image whose first two axes are windowed and whose last axis
            (presumably channels) is kept whole. With ndim == 4 the
            first axis indexes images and each image is handled the
            same way.
        size: side length of each square patch.
        stride: step between consecutive windows; must divide ``size``
            evenly.

    Returns:
        numpy array of shape (n_patches, size, size, last_axis) built
        with ``np.stack``. Patches are ordered image by image, and
        within an image row by row, left to right.

    Raises:
        ValueError: if ``size`` or ``stride`` is not positive, if
            ``size`` is not a multiple of ``stride``, if ``img_arr.ndim``
            is not 3 or 4, or if no window fits (image smaller than
            ``size`` or empty batch).
    '''
    # check size and stride
    if size <= 0 or stride <= 0:
        raise ValueError('size and stride must be positive')
    if size % stride != 0:
        raise ValueError('size % stride must be equal 0')

    # Treat a single image as a batch of one so both cases share one loop.
    if img_arr.ndim == 3:
        images = img_arr[np.newaxis]
    elif img_arr.ndim == 4:
        images = img_arr
    else:
        raise ValueError('img_arr.ndim must be equal 3 or 4')

    # Count windows separately per axis so non-square images are handled
    # correctly; max() guards against images smaller than the window.
    height, width = images.shape[1], images.shape[2]
    n_rows = max(0, (height - size) // stride + 1)
    n_cols = max(0, (width - size) // stride + 1)

    patches_list = []
    for im in images:
        for i in range(n_rows):
            top = i * stride
            for j in range(n_cols):
                left = j * stride
                patches_list.append(im[top:top + size, left:left + size])

    if not patches_list:
        # np.stack would fail with an unhelpful message on an empty list.
        raise ValueError(
            'no patches could be extracted: image is smaller than size '
            'or img_arr contains no images')
    return np.stack(patches_list)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment