-
-
Save insujeon/37c7be2adf94b62160337525d72aa456 to your computer and use it in GitHub Desktop.
Improved image transform functions for dense predictions (for pytorch, keras etc.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections
import collections.abc
import numbers

import numpy as np
import scipy
import scipy.ndimage
from PIL import Image
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import map_coordinates
__author__ = "Wei OUYANG" | |
__license__ = "GPL" | |
__version__ = "0.1.0" | |
__status__ = "Development" | |
def center_crop(x, center_crop_size):
    """Crop the central region of a channels-first array.

    Args:
        x: array with three dimensions, indexed as (C, H, W).
        center_crop_size: ``(crop_h, crop_w)`` size of the region to keep.

    Returns:
        A view of ``x`` with shape ``(C, crop_h, crop_w)``.

    Raises:
        ValueError: if the crop is larger than the spatial size of ``x``.
    """
    assert x.ndim == 3
    crop_h, crop_w = center_crop_size[0], center_crop_size[1]
    if crop_h > x.shape[1] or crop_w > x.shape[2]:
        raise ValueError('crop size {} is larger than image size {}'.format(
            (crop_h, crop_w), x.shape[1:]))
    # Start at centre - half and take exactly ``crop`` pixels. The previous
    # ``centre - half : centre + half`` form dropped one pixel for odd sizes
    # (and is identical to this for even sizes).
    top = x.shape[1] // 2 - crop_h // 2
    left = x.shape[2] // 2 - crop_w // 2
    return x[:, top:top + crop_h, left:left + crop_w]
def to_tensor(x):
    """Convert an (H, W, C) numpy array into a float32 torch tensor (C, H, W).

    A 2-D (H, W) array is treated as a single-channel image and yields a
    ``(1, H, W)`` tensor.
    """
    import torch
    if x.ndim == 2:
        x = x[:, :, np.newaxis]
    # torch.from_numpy rejects arrays with negative strides (e.g. produced by
    # np.fliplr or x[:, ::-1]); make the channels-first view contiguous first.
    x = np.ascontiguousarray(x.transpose((2, 0, 1)))
    return torch.from_numpy(x).float()
def random_num_generator(config, random_state=np.random):
    """Draw a single random number described by ``config``.

    Args:
        config: sequence ``(kind, a, b)``. ``kind`` is ``'uniform'`` (sample
            from ``[a, b)``) or ``'lognormal'`` (``a``/``b`` are the mean and
            sigma of the underlying normal distribution).
        random_state: ``numpy.random`` or a ``numpy.random.RandomState``.

    Returns:
        One sampled value.

    Raises:
        ValueError: if ``kind`` is not supported.
    """
    kind = config[0]
    if kind == 'uniform':
        sampler = random_state.uniform
    elif kind == 'lognormal':
        sampler = random_state.lognormal
    else:
        # Report the offending config in the exception rather than printing it.
        raise ValueError('unsupported format: {!r}'.format(config))
    return sampler(config[1], config[2], 1)[0]
def poisson_downsampling(image, peak, random_state=np.random):
    """Simulate a low-photon acquisition by Poisson-sampling an image.

    Each channel is rescaled so that its maximum becomes ``peak``, and every
    pixel is replaced by a Poisson draw with that rescaled value as the rate.
    The output is therefore on the count scale (roughly ``[0, peak]``), not on
    the intensity scale of the input.

    Args:
        image: ndarray of shape (H, W) or (H, W, C), or anything ``np.array``
            accepts. Values presumably must be non-negative (Poisson rates).
        peak: expected count at each channel's maximum.
        random_state: ``numpy.random`` or a ``numpy.random.RandomState``.

    Returns:
        float32 array with the same shape as ``image``. An all-zero image is
        returned (as float32) without drawing any random numbers.
    """
    if not isinstance(image, np.ndarray):
        img = np.array(image, dtype='float32')
    else:
        img = image.astype('float32')
    # Per-channel scale factor (a scalar for 2-D input).
    scale = img.max(axis=(0, 1)) / peak
    if np.all(scale == 0):
        return img
    # An all-zero channel would give 0/0 = NaN, which Poisson sampling
    # rejects; a scale of 1 keeps its rate at 0, i.e. it stays black.
    scale = np.where(scale == 0, 1.0, scale)
    noisy = random_state.poisson(lam=img / scale)
    return noisy.astype('float32')
def elastic_transform(image, alpha=1000, sigma=30, spline_order=1, mode='nearest', random_state=np.random):
    """Randomly warp an (H, W, C) image with a smooth displacement field.

    Implements the elastic distortion of [Simard2003]_. Uniform noise in
    ``[-1, 1)`` is smoothed by a Gaussian of width ``sigma`` and scaled by
    ``alpha`` to give row and column offsets. Every channel is then resampled
    at the displaced coordinates using that same field.

    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
       Convolutional Neural Networks applied to Visual Document Analysis", in
       Proc. of the International Conference on Document Analysis and
       Recognition, 2003.

    Args:
        image: array with exactly three dimensions (H, W, C).
        alpha: magnitude of the displacement.
        sigma: standard deviation of the smoothing Gaussian.
        spline_order: interpolation order passed to ``map_coordinates``.
        mode: boundary handling passed to ``map_coordinates``.
        random_state: ``numpy.random`` or a ``RandomState`` providing ``rand``.

    Returns:
        Array with the same shape and dtype as ``image``.
    """
    assert image.ndim == 3
    plane = image.shape[:2]

    def _smooth_offsets():
        # One smoothed, scaled noise field over the (H, W) plane.
        noise = random_state.rand(*plane) * 2 - 1
        return gaussian_filter(noise, sigma, mode="constant", cval=0) * alpha

    # Draw order matters for reproducibility: row offsets first, then columns.
    row_shift = _smooth_offsets()
    col_shift = _smooth_offsets()
    rows, cols = np.indices(plane)
    sample_at = np.array([rows + row_shift, cols + col_shift])

    warped = np.empty_like(image)
    for channel in range(image.shape[2]):
        warped[:, :, channel] = map_coordinates(
            image[:, :, channel], sample_at, order=spline_order, mode=mode)
    return warped
class Merge(object):
    """Merge a group of numpy arrays by concatenating them along ``axis``.

    All arrays must have identical shapes except along the merge axis.
    """

    def __init__(self, axis=-1):
        self.axis = axis

    def __call__(self, images):
        """Return ``np.concatenate(images, axis=self.axis)`` after validation.

        Raises:
            TypeError: if ``images`` is neither a sequence nor an ndarray.
            AssertionError: if an element is not an ndarray, or shapes differ
                outside the merge axis.
        """
        # collections.Sequence was removed in Python 3.10; use collections.abc.
        if not isinstance(images, (collections.abc.Sequence, np.ndarray)):
            raise TypeError("obj is not a sequence (list, tuple, etc)")
        assert all(isinstance(i, np.ndarray)
                   for i in images), 'only numpy array is supported'
        shapes = [list(i.shape) for i in images]
        for s in shapes:
            # Blank out the merge axis so only the other dimensions are compared.
            s[self.axis] = None
        assert all(s == shapes[0] for s in shapes), \
            'shapes must be the same except the merge axis'
        return np.concatenate(images, axis=self.axis)
class Split(object):
    """Split an array into individual pieces along one axis.

    Each positional argument is either a ``slice`` or a sequence of
    ``slice()`` arguments; e.g. ``Split([0, 3], [3, 6])`` yields indices
    0-2 and 3-5 of the last axis. Pass ``axis=`` to split another axis.
    """

    def __init__(self, *slices, **kwargs):
        slices_ = []
        for s in slices:
            if isinstance(s, collections.abc.Sequence):
                slices_.append(slice(*s))
            else:
                slices_.append(s)
        assert all(isinstance(s, slice) for s in slices_), \
            'slices must be consist of slice instances'
        self.slices = slices_
        self.axis = kwargs.get('axis', -1)

    def __call__(self, image):
        """Return a list with one view of ``image`` per configured slice.

        Raises:
            TypeError: if ``image`` is not a numpy array.
        """
        if not isinstance(image, np.ndarray):
            raise TypeError("obj is not an numpy array")
        ret = []
        for s in self.slices:
            index = [slice(None)] * image.ndim
            index[self.axis] = s
            # NumPy >= 1.23 rejects a *list* of slices as a multidimensional
            # index; it has to be a tuple.
            ret.append(image[tuple(index)])
        return ret
class ElasticTransform(object):
    """Apply elastic transformation on a numpy.ndarray (H x W x C).

    ``alpha`` and ``sigma`` are numbers, or ``(kind, a, b)`` configs for
    ``random_num_generator`` that are re-drawn on every call.
    """

    def __init__(self, alpha, sigma, random_state=np.random):
        self.alpha = alpha
        self.sigma = sigma
        # Added (defaulting to the global generator) for parity with the other
        # random transforms here, so the parameters and the displacement field
        # can be made reproducible.
        self.random_state = random_state

    def __call__(self, image):
        # collections.Sequence was removed in Python 3.10; use collections.abc.
        if isinstance(self.alpha, collections.abc.Sequence):
            alpha = random_num_generator(
                self.alpha, random_state=self.random_state)
        else:
            alpha = self.alpha
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        return elastic_transform(image, alpha=alpha, sigma=sigma,
                                 random_state=self.random_state)
class PoissonSubsampling(object):
    """Poisson subsampling on a numpy.ndarray (H x W x C).

    ``peak`` is a number, or a ``(kind, a, b)`` config for
    ``random_num_generator`` that is re-drawn on every call. The sampling
    itself is done by ``poisson_downsampling``.
    """

    def __init__(self, peak, random_state=np.random):
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        # collections.Sequence was removed in Python 3.10; use collections.abc.
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(
                self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        return poisson_downsampling(image, peak, random_state=self.random_state)
class AddGaussianNoise(object):
    """Add gaussian noise to a numpy.ndarray (H x W x C).

    ``mean`` and ``sigma`` are numbers, or ``(kind, a, b)`` configs for
    ``random_num_generator`` that are re-drawn on every call.
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        """Return ``image`` plus i.i.d. N(mean, sigma**2) noise.

        The input array is not modified. The result dtype is
        ``np.result_type(image.dtype, np.float32)``: float32/float64 inputs
        keep their dtype, and integer inputs are promoted so the noise is not
        truncated.
        """
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        # isinstance() used to be called with a random_state keyword here,
        # which raised TypeError on every call.
        if isinstance(self.mean, collections.abc.Sequence):
            mean = random_num_generator(
                self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        noise = self.random_state.normal(mean, sigma, image.shape)
        # Work on a float copy: in-place ``image += noise`` mutated the
        # caller's array and fails outright for integer dtypes.
        noisy = image.astype(np.result_type(image.dtype, np.float32))
        noisy += noise
        return noisy
class AddSpeckleNoise(object):
    """Add speckle (multiplicative) noise to a numpy.ndarray (H x W x C).

    ``mean`` and ``sigma`` are numbers, or ``(kind, a, b)`` configs for
    ``random_num_generator`` that are re-drawn on every call.
    """

    def __init__(self, mean, sigma, random_state=np.random):
        self.sigma = sigma
        self.mean = mean
        self.random_state = random_state

    def __call__(self, image):
        """Return ``image + image * n`` with n ~ N(mean, sigma**2) per element.

        The input array is not modified. The result dtype is
        ``np.result_type(image.dtype, np.float32)``, so integer inputs are
        promoted instead of failing.
        """
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.mean, collections.abc.Sequence):
            mean = random_num_generator(
                self.mean, random_state=self.random_state)
        else:
            mean = self.mean
        noise = self.random_state.normal(mean, sigma, image.shape)
        # Float copy: the in-place update used to mutate the caller's array and
        # raised for integer dtypes.
        noisy = image.astype(np.result_type(image.dtype, np.float32))
        noisy += noisy * noise
        return noisy
class GaussianBlurring(object):
    """Apply a spatial gaussian blur to a numpy.ndarray (H x W) or (H x W x C).

    ``sigma`` is a number, or a ``(kind, a, b)`` config for
    ``random_num_generator`` that is re-drawn on every call.
    """

    def __init__(self, sigma, random_state=np.random):
        self.sigma = sigma
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        # Blur the two spatial axes only; any trailing (channel) axes get
        # sigma 0 so channels are not mixed. For 3-D input this is
        # (sigma, sigma, 0), as before.
        sigmas = (sigma, sigma) + (0,) * (image.ndim - 2)
        return gaussian_filter(image, sigma=sigmas)
class AddGaussianPoissonNoise(object):
    """Add Poisson noise derived from a gaussian-blurred copy of a
    numpy.ndarray (H x W x C).

    The image is blurred spatially with ``sigma``, passed through
    ``poisson_downsampling`` with ``peak``, and the result is added to the
    original. ``sigma`` and ``peak`` are numbers, or ``(kind, a, b)`` configs
    for ``random_num_generator`` that are re-drawn on every call.
    """

    def __init__(self, sigma, peak, random_state=np.random):
        self.sigma = sigma
        self.peak = peak
        self.random_state = random_state

    def __call__(self, image):
        if isinstance(self.sigma, collections.abc.Sequence):
            sigma = random_num_generator(
                self.sigma, random_state=self.random_state)
        else:
            sigma = self.sigma
        if isinstance(self.peak, collections.abc.Sequence):
            peak = random_num_generator(
                self.peak, random_state=self.random_state)
        else:
            peak = self.peak
        # Blur the spatial axes only; trailing (channel) axes get sigma 0.
        sigmas = (sigma, sigma) + (0,) * (image.ndim - 2)
        bg = gaussian_filter(image, sigma=sigmas)
        # NOTE(review): poisson_downsampling returns counts scaled so that each
        # channel's maximum has rate ``peak``. The added background is on that
        # scale rather than the image's; confirm this is intended.
        bg = poisson_downsampling(
            bg, peak=peak, random_state=self.random_state)
        return image + bg
class MaxScaleNumpy(object):
    """Min-max scale each channel of a numpy array into [range_min, range_max]:

    channel = range_min + (channel - min) * (range_max - range_min) / (max - min)

    Statistics are taken over the first two axes. A constant channel
    (max == min) maps to ``range_min`` instead of producing NaN.
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        lo = image.min(axis=(0, 1))
        hi = image.max(axis=(0, 1))
        # Avoid 0/0 for constant channels; np.where also handles the scalar
        # statistics of 2-D input.
        span = hi - lo
        span = np.where(span == 0, 1, span)
        return self.scale[0] + (image - lo) * (self.scale[1] - self.scale[0]) / span
class MedianScaleNumpy(object):
    """Scale each channel of a numpy array with its minimum and median:

    channel = range_min + (channel - min) * (range_max - range_min) / (median - min)

    The minimum maps to ``range_min`` and the median to ``range_max``; values
    above the median exceed ``range_max``. Statistics are taken over the first
    two axes. A channel whose median equals its minimum (e.g. a mostly black
    image) is only shifted, instead of producing inf/NaN.
    """

    def __init__(self, range_min=0.0, range_max=1.0):
        self.scale = (range_min, range_max)

    def __call__(self, image):
        lo = image.min(axis=(0, 1))
        median = np.median(image, axis=(0, 1))
        span = median - lo
        span = np.where(span == 0, 1, span)
        return self.scale[0] + (image - lo) * (self.scale[1] - self.scale[0]) / span
class NormalizeNumpy(object):
    """Normalize each channel of the numpy array i.e.
    channel = (channel - mean) / std

    Statistics are taken over the first two axes. Channels with zero standard
    deviation are only mean-centred. The input is not modified; the result
    dtype is ``np.result_type(image.dtype, np.float32)``, so integer images
    are accepted.
    """

    def __call__(self, image):
        # Float copy: in-place arithmetic on the input mutated the caller's
        # array and raised for integer dtypes.
        out = image.astype(np.result_type(image.dtype, np.float32))
        out -= out.mean(axis=(0, 1))
        std = out.std(axis=(0, 1))
        # np.where (rather than item assignment) also works for the scalar
        # std of a 2-D image.
        std = np.where(std == 0, 1.0, std)
        out /= std
        return out
class MutualExclude(object):
    """Zero one channel wherever another channel is positive.

    When ``from_channel`` is an int or slice, ``image[:, :, from_channel]`` is
    a view, so the zeroing is written into ``image`` itself, which is also
    returned.
    """

    def __init__(self, exclude_channel, from_channel):
        self.from_channel = from_channel
        self.exclude_channel = exclude_channel

    def __call__(self, image):
        occupied = image[:, :, self.exclude_channel] > 0
        target = image[:, :, self.from_channel]
        target[occupied] = 0
        return image
class RandomCropNumpy(object):
    """Crops the given numpy array at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size, random_state=np.random):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.random_state = random_state

    def __call__(self, img):
        """Return a random (target_height, target_width) window of ``img``.

        ``img`` is indexed as (H, W) or (H, W, ...); axis 0 is the height.

        Raises:
            ValueError: if the window is larger than the image.
        """
        h, w = img.shape[:2]
        th, tw = self.size
        if h == th and w == tw:
            return img
        if th > h or tw > w:
            raise ValueError('crop size {} is larger than image size {}'.format(
                (th, tw), (h, w)))
        # randint's upper bound is exclusive. The +1 makes the last valid
        # offset reachable, and gives offset 0 (rather than an error from
        # randint(0, 0)) when one axis already has the target size.
        top = self.random_state.randint(0, h - th + 1)
        left = self.random_state.randint(0, w - tw + 1)
        return img[top:top + th, left:left + tw]
class CenterCropNumpy(object):
    """Crops the given numpy array at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        """Return the central (target_height, target_width) window of ``img``.

        ``img`` is indexed as (H, W) or (H, W, ...); axis 0 is the height.

        Raises:
            ValueError: if the window is larger than the image.
        """
        h, w = img.shape[:2]
        th, tw = self.size
        if th > h or tw > w:
            raise ValueError('crop size {} is larger than image size {}'.format(
                (th, tw), (h, w)))
        # Height offsets index axis 0 and width offsets axis 1; the old code
        # swapped them for non-square crops.
        top = int(round((h - th) / 2.))
        left = int(round((w - tw) / 2.))
        return img[top:top + th, left:left + tw]
class RandomRotate(object): | |
"""Rotate a PIL.Image or numpy.ndarray (H x W x C) randomly | |
""" | |
def __init__(self, angle_range=(0.0, 360.0), axes=(0, 1), mode='reflect', random_state=np.random): | |
assert isinstance(angle_range, tuple) | |
self.angle_range = angle_range | |
self.random_state = random_state | |
self.axes = axes | |
self.mode = mode | |
def __call__(self, image): | |
angle = self.random_state.uniform( | |
self.angle_range[0], self.angle_range[1]) | |
if isinstance(image, np.ndarray): | |
mi, ma = image.min(), image.max() | |
image = scipy.ndimage.interpolation.rotate( | |
image, angle, reshape=False, axes=self.axes, mode=self.mode) | |
return np.clip(image, mi, ma) | |
elif isinstance(image, Image.Image): | |
return image.rotate(angle) | |
else: | |
raise Exception('unsupported type') | |
class BilinearResize(object):
    """Resize a PIL.Image or numpy.ndarray (H x W [x C]) by a zoom factor
    using bilinear interpolation.
    """

    def __init__(self, zoom):
        # Per-axis factors for an (H, W, C) array: scale H and W, keep C.
        self.zoom = [zoom, zoom, 1]

    def __call__(self, image):
        """Return ``image`` resized by the zoom factor.

        Raises:
            TypeError: if ``image`` is neither an ndarray nor a PIL image.
        """
        factor = self.zoom[0]
        if isinstance(image, np.ndarray):
            # Scale the two spatial axes; leave any trailing axes untouched.
            factors = self.zoom[:2] + [1] * (image.ndim - 2)
            # order=1 is (bi)linear; scipy's default order=3 is cubic.
            return scipy.ndimage.zoom(image, factors, order=1)
        if isinstance(image, Image.Image):
            # The old code referenced a never-assigned ``self.size``. Round like
            # scipy.ndimage.zoom does for the output shape.
            width, height = image.size
            size = (int(round(width * factor)), int(round(height * factor)))
            return image.resize(size, Image.BILINEAR)
        raise TypeError('unsupported type')
class EnhancedCompose(object):
    """Composes several transforms together, with support for image groups.

    Each entry of ``transforms`` is one of:

    * a callable, applied to the current value;
    * ``None``, skipped;
    * a sequence of callables/``None``, applied element-wise to a sequence of
      images of the same length (``None`` entries pass their image through).

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example::

        EnhancedCompose([
            Merge(),
            Split([0, 3], [3, 6]),
            [CenterCropNumpy(256), None],
        ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Run ``img`` through every transform in order and return the result.

        Raises:
            AssertionError: if a transform group and the image group differ
                in length (or the current value is not a sequence).
            TypeError: if an entry is not callable, None, or a sequence.
        """
        for t in self.transforms:
            # collections.Sequence was removed in Python 3.10; use collections.abc.
            if isinstance(t, collections.abc.Sequence):
                assert (isinstance(img, collections.abc.Sequence)
                        and len(img) == len(t)), \
                    "size of image group and transform group does not fit"
                img = [fn(im) if callable(fn) else im for fn, im in zip(t, img)]
            elif callable(t):
                img = t(img)
            elif t is None:
                continue
            else:
                raise TypeError('unexpected type')
        return img
if __name__ == '__main__':
    # Demo: apply one shared random crop/rotation to an (input, target) image
    # pair, then process each half separately.
    input_channel = 3
    target_channel = 3
    # define a transform pipeline
    transform = EnhancedCompose([
        Merge(),
        RandomCropNumpy(size=(512, 512)),
        RandomRotate(),
        Split([0, input_channel], [input_channel, input_channel + target_channel]),
        [CenterCropNumpy(size=(256, 256)), CenterCropNumpy(size=(256, 256))],
        [NormalizeNumpy(), MaxScaleNumpy(0, 1.0)],
        # for non-pytorch usage, remove to_tensor conversion. to_tensor is a
        # plain callable, so no torchvision Lambda wrapper is needed.
        [to_tensor, to_tensor],
    ])

    def _load_rgb(path):
        """Read an image file as a float32 (H, W, 3) array, closing the file."""
        # RGB conversion guarantees the 3 channels the Split above expects.
        # float32 keeps the normalisation/scaling steps from operating on
        # uint8 data.
        with Image.open(path) as im:
            return np.array(im.convert('RGB'), dtype=np.float32)

    # read input data for test (both images must share the same H x W, at
    # least 512 x 512 for the random crop above)
    image_in = _load_rgb('input.jpg')
    image_target = _load_rgb('target.jpg')
    # apply the transform
    x, y = transform([image_in, image_target])
    print('input tensor: {}, target tensor: {}'.format(
        tuple(x.shape), tuple(y.shape)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment