Created
June 4, 2018 13:27
-
-
Save jeasinema/ac4bac4ba07d10b4f8ec1c7f7a00dfb7 to your computer and use it in GitHub Desktop.
Generic augmentation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torchvision import transforms | |
import cv2 | |
import numpy as np | |
import types | |
from numpy import random | |
def intersect(box_a, box_b):
    """Return the overlap area between every row of ``box_a`` and ``box_b``.

    Args:
        box_a: 2-D array whose first four columns are read as
            ``x1, y1, x2, y2`` (one box per row).
        box_b: 1-D array read the same way (a single box).

    Returns:
        1-D array with one intersection area per row of ``box_a``; rows that
        do not overlap ``box_b`` yield 0.
    """
    # The overlap rectangle is bounded by the innermost corners of both boxes.
    lower_right = np.minimum(box_a[:, 2:], box_b[2:])
    upper_left = np.maximum(box_a[:, :2], box_b[:2])
    # A negative extent means the boxes are disjoint on that axis.
    extent = np.clip(lower_right - upper_left, a_min=0, a_max=np.inf)
    overlap_w, overlap_h = extent[:, 0], extent[:, 1]
    return overlap_w * overlap_h
def jaccard_numpy(box_a, box_b):
    """Compute the jaccard overlap (intersection over union) of boxes.

    For two boxes A and B::

        IoU = A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

    Args:
        box_a: Multiple bounding boxes, Shape: [num_boxes, 4]
        box_b: Single bounding box, Shape: [4]

    Return:
        jaccard overlap of each row of ``box_a`` with ``box_b``,
        Shape: [num_boxes]
    """
    overlap_area = intersect(box_a, box_b)
    # Per-row area of box_a.
    widths_a = box_a[:, 2] - box_a[:, 0]
    heights_a = box_a[:, 3] - box_a[:, 1]
    # Scalar area of the single reference box.
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union_area = widths_a * heights_a + area_b - overlap_area
    return overlap_area / union_area
class Compose(object):
    """Chain several augmentations into one callable.

    Each transform is called as ``t(img, boxes, labels)`` and must return a
    3-item result that is fed to the next transform.

    Args:
        transforms (List[Transform]): transforms to apply, in order.

    Example:
        >>> pipeline = Compose([
        >>>     lambda img, boxes, labels: (img * 2, boxes, labels),
        >>>     lambda img, boxes, labels: (img + 1, boxes, labels),
        >>> ])
        >>> pipeline(3)
        (7, None, None)
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, boxes=None, labels=None):
        current = (img, boxes, labels)
        for step in self.transforms:
            # Unpack eagerly so a transform returning the wrong number of
            # items fails at the offending step.
            out_img, out_boxes, out_labels = step(*current)
            current = (out_img, out_boxes, out_labels)
        return current
class Lambda(object):
    """Wrap a plain function so it can be used as a transform.

    Args:
        lambd: function called as ``lambd(img, boxes, labels)``; its return
            value is passed through unchanged.
    """

    def __init__(self, lambd):
        # LambdaType is the type of any Python-level function, not only of
        # ``lambda`` expressions.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, img, boxes=None, labels=None):
        fn = self.lambd
        return fn(img, boxes, labels)
class ConvertFromInts(object):
    """Cast the image to ``float32``; boxes and labels pass through as-is."""

    def __call__(self, image, boxes=None, labels=None):
        as_float = image.astype(np.float32)
        return as_float, boxes, labels
class SubtractMeans(object):
    """Subtract a fixed mean from a float32 copy of the image.

    Args:
        mean: value(s) stored as a ``float32`` array and broadcast against
            the image during subtraction.
    """

    def __init__(self, mean):
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image, boxes=None, labels=None):
        # astype() copies, so the caller's array is never modified.
        centered = image.astype(np.float32)
        centered -= self.mean
        result = centered.astype(np.float32)
        return result, boxes, labels
class ToAbsoluteCoords(object):
    """Scale box coordinates by the image size, in place.

    Columns 0 and 2 of ``boxes`` are multiplied by the image width, columns
    1 and 3 by the image height. The same ``boxes`` array is returned.
    """

    def __call__(self, image, boxes=None, labels=None):
        height, width, _ = image.shape
        # (column, scale) pairs, applied in place on the caller's array.
        for column, scale in ((0, width), (2, width), (1, height), (3, height)):
            boxes[:, column] *= scale
        return image, boxes, labels
class ToPercentCoords(object):
    """Divide box coordinates by the image size, in place.

    Columns 0 and 2 of ``boxes`` are divided by the image width, columns 1
    and 3 by the image height. The same ``boxes`` array is returned.
    """

    def __call__(self, image, boxes=None, labels=None):
        height, width, _ = image.shape
        # (column, divisor) pairs, applied in place on the caller's array.
        for column, divisor in ((0, width), (2, width), (1, height), (3, height)):
            boxes[:, column] /= divisor
        return image, boxes, labels
class Resize(object):
    """Resize the image to a ``size`` x ``size`` square with ``cv2.resize``.

    Boxes and labels are returned untouched.

    Args:
        size: edge length passed to OpenCV for both dimensions.
    """

    def __init__(self, size=300):
        self.size = size

    def __call__(self, image, boxes=None, labels=None):
        target = (self.size, self.size)
        resized = cv2.resize(image, target)
        return resized, boxes, labels
class RandomSaturation(object):
    """Randomly scale channel 1 of the image, in place.

    With probability 1/2 the second channel is multiplied by a factor drawn
    uniformly from ``[lower, upper)``. ``PhotometricDistort`` in this file
    runs this transform on an HSV image, where channel 1 is the saturation.

    Args:
        lower: smallest scale factor; must be non-negative.
        upper: largest scale factor; must be >= ``lower``.

    Raises:
        AssertionError: if the bounds are inconsistent.
    """

    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        # Messages name this transform (they previously said "contrast").
        assert self.upper >= self.lower, "saturation upper must be >= lower."
        assert self.lower >= 0, "saturation lower must be non-negative."

    def __call__(self, image, boxes=None, labels=None):
        if random.randint(2):
            # In-place multiply: expects a float image.
            image[:, :, 1] *= random.uniform(self.lower, self.upper)
        return image, boxes, labels
class RandomHue(object):
    """Randomly shift channel 0 of the image, in place, wrapping at 360.

    With probability 1/2 a value drawn uniformly from ``[-delta, delta)`` is
    added to channel 0; results above 360 are reduced by 360 and results
    below 0 are increased by 360.

    Args:
        delta: maximum absolute shift, between 0 and 360 inclusive.
    """

    def __init__(self, delta=18.0):
        assert 0.0 <= delta <= 360.0
        self.delta = delta

    def __call__(self, image, boxes=None, labels=None):
        if not random.randint(2):
            return image, boxes, labels
        # Basic slicing yields a view, so edits land in the caller's array.
        hue = image[:, :, 0]
        hue += random.uniform(-self.delta, self.delta)
        hue[hue > 360.0] -= 360.0
        hue[hue < 0.0] += 360.0
        return image, boxes, labels
class RandomLightingNoise(object):
    """Randomly permute the image channels.

    With probability 1/2 one of the six orderings of three channels is
    picked uniformly and applied through ``SwapChannels``.
    """

    def __init__(self):
        # Every ordering of channel indices 0, 1, 2.
        self.perms = (
            (0, 1, 2),
            (0, 2, 1),
            (1, 0, 2),
            (1, 2, 0),
            (2, 0, 1),
            (2, 1, 0),
        )

    def __call__(self, image, boxes=None, labels=None):
        if not random.randint(2):
            return image, boxes, labels
        choice = random.randint(len(self.perms))
        reorder = SwapChannels(self.perms[choice])
        return reorder(image), boxes, labels
class ConvertColor(object):
    """Convert the image between colour spaces with ``cv2.cvtColor``.

    Supported (current -> transform) pairs: BGR -> HSV, HSV -> BGR,
    RGB -> HSV and HSV -> RGB.

    Args:
        current: name of the colour space the incoming image is in.
        transform: name of the colour space to convert to.

    Raises:
        NotImplementedError: when called with an unsupported pair.
    """

    def __init__(self, current='BGR', transform='HSV'):
        self.transform = transform
        self.current = current

    def __call__(self, image, boxes=None, labels=None):
        # (source, destination, name of the OpenCV conversion constant)
        supported = (
            ('BGR', 'HSV', 'COLOR_BGR2HSV'),
            ('HSV', 'BGR', 'COLOR_HSV2BGR'),
            ('RGB', 'HSV', 'COLOR_RGB2HSV'),
            ('HSV', 'RGB', 'COLOR_HSV2RGB'),
        )
        for source, destination, code_name in supported:
            if self.current == source and self.transform == destination:
                converted = cv2.cvtColor(image, getattr(cv2, code_name))
                return converted, boxes, labels
        raise NotImplementedError
class RandomContrast(object):
    """Randomly scale every pixel of the image, in place.

    With probability 1/2 the whole image is multiplied by a factor drawn
    uniformly from ``[lower, upper)``.

    Args:
        lower: smallest scale factor; must be non-negative.
        upper: largest scale factor; must be >= ``lower``.
    """

    def __init__(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "contrast upper must be >= lower."
        assert self.lower >= 0, "contrast lower must be non-negative."

    # expects float image
    def __call__(self, image, boxes=None, labels=None):
        if not random.randint(2):
            return image, boxes, labels
        gain = random.uniform(self.lower, self.upper)
        image *= gain
        return image, boxes, labels
class RandomBrightness(object):
    """Randomly add a constant offset to every pixel, in place.

    With probability 1/2 a value drawn uniformly from ``[-delta, delta)`` is
    added to the whole image.

    Args:
        delta: maximum absolute offset, between 0 and 255 inclusive.
    """

    def __init__(self, delta=32):
        assert 0.0 <= delta
        assert 255.0 >= delta
        self.delta = delta

    def __call__(self, image, boxes=None, labels=None):
        if not random.randint(2):
            return image, boxes, labels
        offset = random.uniform(-self.delta, self.delta)
        image += offset
        return image, boxes, labels
class ToCV2Image(object):
    """Turn a tensor into a ``float32`` NumPy array with axes (1, 2, 0).

    The tensor is moved to the CPU, converted to NumPy, cast to float32 and
    its first axis is moved to the end.
    """

    def __call__(self, tensor, boxes=None, labels=None):
        as_array = tensor.cpu().numpy().astype(np.float32)
        channels_last = as_array.transpose((1, 2, 0))
        return channels_last, boxes, labels
class ToTensor(object):
    """Turn a NumPy image into a ``float32`` tensor with axes (2, 0, 1).

    The array is cast to float32, wrapped with ``torch.from_numpy`` and its
    last axis is moved to the front.
    """

    def __call__(self, cvimage, boxes=None, labels=None):
        as_float = cvimage.astype(np.float32)
        channels_first = torch.from_numpy(as_float).permute(2, 0, 1)
        return channels_first, boxes, labels
class RandomSampleCrop(object):
    """Randomly crop the image and keep the boxes centred inside the crop.

    Each call picks a sampling mode at random: return the input unchanged,
    crop subject to a minimum jaccard overlap with the ground-truth boxes,
    or crop with no overlap constraint. Up to 50 candidate patches are tried
    per mode; if none is accepted another mode is drawn.

    Arguments:
        image: the image being input during training, indexed
            ``[row, col, channel]``
        boxes: the original bounding boxes in point form
            (``x1, y1, x2, y2`` per row); ``None`` or an empty array makes
            the call a no-op
        labels: the class labels for each bbox (may be ``None``)
    Return:
        (img, boxes, classes)
            img: the cropped image
            boxes: the adjusted bounding boxes in point form, relative to
                the crop's top-left corner and clipped to the crop
            labels: the class labels for each kept bbox
    """

    def __init__(self):
        self.sample_options = (
            # using entire original input image
            None,
            # sample a patch s.t. MIN jaccard w/ obj in .1, .3, .7, .9
            (0.1, None),
            (0.3, None),
            (0.7, None),
            (0.9, None),
            # randomly sample a patch
            (None, None),
        )

    def __call__(self, image, boxes=None, labels=None):
        height, width, _ = image.shape
        # Without ground-truth boxes there is nothing to constrain or keep,
        # and the overlap statistics below are undefined for an empty set.
        if boxes is None or len(boxes) == 0:
            return image, boxes, labels

        while True:
            # Pick by index: handing the ragged options tuple (None mixed
            # with pairs) to np.random.choice makes NumPy try to build an
            # inhomogeneous array, which recent versions reject.
            mode = self.sample_options[random.randint(len(self.sample_options))]
            if mode is None:
                return image, boxes, labels

            min_iou, max_iou = mode
            if min_iou is None:
                min_iou = float('-inf')
            if max_iou is None:
                max_iou = float('inf')

            # max trials (50)
            for _ in range(50):
                w = random.uniform(0.3 * width, width)
                h = random.uniform(0.3 * height, height)

                # aspect ratio constraint b/t .5 & 2
                if h / w < 0.5 or h / w > 2:
                    continue

                # Both bounds are given explicitly; a single positional
                # argument would be taken as `low` with `high` left at 1.0.
                left = random.uniform(0, width - w)
                top = random.uniform(0, height - h)

                # convert to integer rect x1,y1,x2,y2
                rect = np.array([int(left), int(top),
                                 int(left + w), int(top + h)])

                # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
                overlap = jaccard_numpy(boxes, rect)

                # Reject the patch when either bound is violated. (Joining
                # the two tests with `and` could never reject anything
                # while max_iou is infinite.)
                if overlap.min() < min_iou or overlap.max() > max_iou:
                    continue

                # keep gt boxes whose centre lies strictly inside the patch
                centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
                right_of_origin = (rect[0] < centers[:, 0]) & (rect[1] < centers[:, 1])
                left_of_corner = (rect[2] > centers[:, 0]) & (rect[3] > centers[:, 1])
                mask = right_of_origin & left_of_corner

                # have any valid boxes? try again if not
                if not mask.any():
                    continue

                # cut the crop from the image
                current_image = image[rect[1]:rect[3], rect[0]:rect[2], :]

                # take only matching gt boxes and labels
                current_boxes = boxes[mask, :].copy()
                current_labels = labels[mask] if labels is not None else None

                # clip the boxes to the crop, then shift them so the crop's
                # top-left corner becomes the origin
                current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2])
                current_boxes[:, :2] -= rect[:2]
                current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:])
                current_boxes[:, 2:] -= rect[:2]

                return current_image, current_boxes, current_labels
class Expand(object):
    """Randomly place the image on a larger canvas filled with ``mean``.

    With probability 1/2 the input is returned unchanged. Otherwise a canvas
    1x to 4x the image size is created, filled with ``mean``, the image is
    pasted at a random offset and a copy of ``boxes`` is shifted by that
    offset.

    Args:
        mean: fill value assigned to every pixel of the canvas.
    """

    def __init__(self, mean):
        self.mean = mean

    def __call__(self, image, boxes, labels):
        if random.randint(2):
            return image, boxes, labels

        height, width, depth = image.shape
        ratio = random.uniform(1, 4)
        left = random.uniform(0, width*ratio - width)
        top = random.uniform(0, height*ratio - height)

        canvas_shape = (int(height*ratio), int(width*ratio), depth)
        canvas = np.empty(canvas_shape, dtype=image.dtype)
        canvas[:, :, :] = self.mean
        canvas[int(top):int(top + height),
               int(left):int(left + width)] = image

        # Shift a copy so the caller's boxes stay untouched.
        offset = (int(left), int(top))
        shifted = boxes.copy()
        shifted[:, :2] += offset
        shifted[:, 2:] += offset
        return canvas, shifted, labels
class RandomMirror(object):
    """Randomly flip the image left-to-right together with its boxes.

    With probability 1/2 the image's second axis is reversed and a copy of
    ``boxes`` gets its x coordinates mirrored about the image width.
    """

    def __call__(self, image, boxes, classes):
        _, width, _ = image.shape
        if not random.randint(2):
            return image, boxes, classes
        flipped = image[:, ::-1]
        mirrored = boxes.copy()
        # Columns (2, 0) feed columns (0, 2): the old right edge becomes the
        # new left edge and vice versa.
        mirrored[:, 0::2] = width - boxes[:, 2::-2]
        return flipped, mirrored, classes
class SwapChannels(object):
    """Reorder the last axis of a 3-D image.

    Args:
        swaps (int triple): final order of channels
            eg: (2, 1, 0)
    """

    def __init__(self, swaps):
        self.swaps = swaps

    def __call__(self, image):
        """Return ``image`` with its third axis indexed by ``self.swaps``.

        Args:
            image: array indexed as ``[row, col, channel]``
        Return:
            the image with channels reordered according to ``swaps``
        """
        order = self.swaps
        return image[:, :, order]
class RandomGrayscale(object):
    """Randomly replace the image with a 3-channel grayscale version.

    A digit is drawn uniformly from 0..9 and the conversion is applied when
    it is below ``10 * p``. The image is cast to ``uint8`` for the OpenCV
    conversion and the result is cast back to the input dtype.

    Args:
        p: conversion probability (effective resolution is tenths).
        current: channel layout of the incoming image, 'RGB' or 'BGR'.

    Raises:
        NotImplementedError: when a conversion is triggered and ``current``
            is neither 'RGB' nor 'BGR'.
    """

    def __init__(self, p=0.1, current='RGB'):
        self.p = p
        self.current = current

    def __call__(self, image, boxes=None, labels=None):
        source_dtype = image.dtype
        if np.random.randint(10) < 10*self.p:
            if self.current == 'RGB':
                code = cv2.COLOR_RGB2GRAY
            elif self.current == 'BGR':
                code = cv2.COLOR_BGR2GRAY
            else:
                raise NotImplementedError
            gray = cv2.cvtColor(image.astype(np.uint8), code)
            # Stack the single gray plane three times along a new last axis.
            image = np.repeat(gray[..., np.newaxis], 3, -1)
        return image.astype(source_dtype), boxes, labels
class PhotometricDistort(object):
    """Apply a random chain of photometric distortions to a copy of the image.

    Brightness is perturbed first. Then, with equal probability, either the
    contrast step runs before the HSV saturation/hue steps or after them.
    Finally the channels may be shuffled by ``RandomLightingNoise``.

    Args:
        current: channel layout of the incoming image; it is converted to
            HSV for the saturation/hue steps and back afterwards.
    """

    def __init__(self, current='BGR'):
        to_hsv = ConvertColor(current=current, transform='HSV')
        from_hsv = ConvertColor(current='HSV', transform=current)
        # A contrast step sits at both ends; each call drops one of them.
        self.pd = [
            RandomContrast(),
            to_hsv,
            RandomSaturation(),
            RandomHue(),
            from_hsv,
            RandomContrast(),
        ]
        self.rand_brightness = RandomBrightness()
        self.rand_light_noise = RandomLightingNoise()

    def __call__(self, image, boxes, labels):
        # Work on a copy: the individual distortions modify their input.
        im, boxes, labels = self.rand_brightness(image.copy(), boxes, labels)
        stages = self.pd[:-1] if random.randint(2) else self.pd[1:]
        im, boxes, labels = Compose(stages)(im, boxes, labels)
        return self.rand_light_noise(im, boxes, labels)
class SSDAugmentation(object):
    """Full training-time augmentation pipeline.

    Runs, in order: float conversion, absolute box coordinates, photometric
    distortion, canvas expansion, random crop, random mirror, fractional box
    coordinates, resize and mean subtraction.

    Args:
        size: edge length handed to ``Resize``.
        mean: per-channel mean used as the ``Expand`` fill value and
            subtracted at the end.
    """

    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        steps = [
            ConvertFromInts(),
            ToAbsoluteCoords(),
            PhotometricDistort(),
            Expand(self.mean),
            RandomSampleCrop(),
            RandomMirror(),
            ToPercentCoords(),
            Resize(self.size),
            SubtractMeans(self.mean),
        ]
        self.augment = Compose(steps)

    def __call__(self, img, boxes, labels):
        pipeline = self.augment
        return pipeline(img, boxes, labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment