Skip to content

Instantly share code, notes, and snippets.

@patricksnape
Created March 9, 2015 20:58
Show Gist options
  • Save patricksnape/1ce75c3cd44736266aeb to your computer and use it in GitHub Desktop.
Save patricksnape/1ce75c3cd44736266aeb to your computer and use it in GitHub Desktop.
TPS Batched Warps
from __future__ import division
def warp_to_shape(image, template_shape, transform, batch_size=5000,
                  warp_landmarks=False, order=1, mode='constant', cval=0.):
    """
    Return a copy of ``image`` warped into a different reference space.

    For non-affine transforms (e.g. a thin plate spline) the template
    points are pushed through ``transform`` in batches of ``batch_size``
    points, so the intermediate arrays built by the transform stay bounded
    in size instead of scaling with the full number of template pixels.

    Parameters
    ----------
    image : :map:`Image`
        The image to warp.
    template_shape : `tuple` or `ndarray`
        Defines the shape of the result, and what pixel indices should be
        sampled (all of them).
    transform : :map:`Transform`
        Transform **from the template_shape space back to ``image``**.
        Defines, for each index on template_shape, which pixel location
        should be sampled from on ``image``.
    batch_size : `int`, optional
        How many template points are warped at one time. Lower values are
        slower but use less memory. Must be at least 1 when the batched
        path is used; it is ignored on the optimised 2D affine path.
    warp_landmarks : `bool`, optional
        If ``True`` and ``image`` has landmarks, the result is given
        ``image``'s landmarks, updated by the pseudoinverse of
        ``transform``.
    order : `int`, optional
        The order of interpolation. The order has to be in the range [0,5]

        ========= ====================
        Order     Interpolation
        ========= ====================
        0         Nearest-neighbor
        1         Bi-linear *(default)*
        2         Bi-quadratic
        3         Bi-cubic
        4         Bi-quartic
        5         Bi-quintic
        ========= ====================

    mode : ``{constant, nearest, reflect, wrap}``, optional
        Points outside the boundaries of the input are filled according
        to the given mode.
    cval : `float`, optional
        Used in conjunction with mode ``constant``, the value outside
        the image boundaries.

    Returns
    -------
    warped_image : :map:`Image`
        A new image of shape ``template_shape`` (plus a channel axis)
        holding the sampled pixels. A plain :map:`Image` is always
        constructed, so the concrete subclass of ``image`` is not
        preserved. Any NaN produced by sampling is replaced with 0.

    Raises
    ------
    ValueError
        If the batched (non-affine) path is taken and ``batch_size`` is
        less than 1.
    """
    import numpy as np
    from menpo.transform import Affine
    from menpo.image import Image
    from menpo.image.interpolation import (scipy_interpolation,
                                           cython_interpolation)
    from menpo.image.base import indices_for_image_of_shape

    # Normalise to a tuple. An ndarray is documented as valid input, but
    # ``ndarray + tuple`` broadcast-adds rather than concatenating when the
    # output shape is built below (and a list would raise TypeError).
    template_shape = tuple(template_shape)

    if (isinstance(transform, Affine) and order in range(4) and
            image.n_dims == 2):
        # skimage has an optimised Cython interpolation for 2D affine
        # warps, so no batching is performed on this path.
        sampled = cython_interpolation(image.pixels, template_shape,
                                       transform, order=order,
                                       mode=mode, cval=cval)
    else:
        # A non-positive step would make range() raise (0) or yield nothing
        # (negative), silently returning an all-zero image.
        if batch_size < 1:
            raise ValueError(
                'batch_size must be at least 1, got {}'.format(batch_size))
        template_points = indices_for_image_of_shape(template_shape)
        n_points = template_points.shape[0]
        # Sampled values are stored as a (n_points, n_channels) array:
        # one row per template point, one column per channel.
        sampled = np.zeros([n_points, image.n_channels])
        for lo_ind in range(0, n_points, batch_size):
            # Slicing past the end is safe, so the final (shorter) batch
            # needs no special handling.
            hi_ind = lo_ind + batch_size
            points_to_sample = transform.apply(template_points[lo_ind:hi_ind])
            sampled[lo_ind:hi_ind, :] = scipy_interpolation(image.pixels,
                                                            points_to_sample,
                                                            order=order,
                                                            mode=mode,
                                                            cval=cval)
    # set any nan values to 0
    sampled[np.isnan(sampled)] = 0
    # build a warped version of the image
    warped_pixels = sampled.reshape(template_shape + (image.n_channels,))
    warped_image = Image(warped_pixels, copy=False)
    # warp landmarks if requested.
    if warp_landmarks and image.has_landmarks:
        # NOTE(review): this relies on the ``landmarks`` setter copying the
        # manager; if it does not, ``image``'s own landmarks are modified
        # in place by the next line -- confirm against menpo.
        warped_image.landmarks = image.landmarks
        transform.pseudoinverse().apply_inplace(warped_image.landmarks)
    if hasattr(image, 'path'):
        warped_image.path = image.path
    return warped_image
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment