Skip to content

Instantly share code, notes, and snippets.

@stephenyan1231
Created February 21, 2021 20:42
Show Gist options
  • Select an option

  • Save stephenyan1231/91091040c726c3fa630e8a3e22128f21 to your computer and use it in GitHub Desktop.

Select an option

Save stephenyan1231/91091040c726c3fa630e8a3e22128f21 to your computer and use it in GitHub Desktop.
deeplearning/projects/classy_vision/fb/dataset/transforms/autoaug_video.py
#!/usr/bin/env python3
# (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
# Copied from D25942231, D22269078 but MODIFIED FOR VIDEO,
# and referred D24414029
""" Auto Augment
Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172
Note that the TF auto-augmentation is different from the PyTorch auto-augmentation
TF's implementation is also different from the original paper
Hacked together by Ross Wightman (https://github.com/rwightman) and MODIFIED FOR VIDEO
"""
import math
import random
import numpy as np
import PIL
import torch
from classy_vision.dataset.transforms import ClassyTransform, register_transform
from PIL import Image, ImageEnhance, ImageOps
# NOTE: a duplicate `_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)`
# assignment previously lived here; it was dead code because the name is
# unconditionally redefined further down in this file before first use,
# so it has been removed.

# Human-readable names for PIL resampling filters (useful for repr/logging).
_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}
def _pil_interp(method):
    """Translate an interpolation name into its PIL resampling constant.

    Any unrecognized name (including "bilinear" itself) falls back to
    Image.BILINEAR.
    """
    named = {
        "bicubic": Image.BICUBIC,
        "lanczos": Image.LANCZOS,
        "hamming": Image.HAMMING,
    }
    # default bilinear, do we want to allow nearest?
    return named.get(method, Image.BILINEAR)
# Pillow version as an (major, minor) int tuple, e.g. (5, 2); used below to
# select compatible affine/rotate code paths.
_PIL_VER = tuple(int(x) for x in PIL.__version__.split(".")[:2])
# Mid-gray fill for pixels exposed by geometric transforms.
_FILL = (128, 128, 128)
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0
# Fallback hyper-parameters used when an op is built without explicit hparams.
_HPARAMS_DEFAULT = {"translate_const": 250, "img_mean": _FILL}
# Pool of resampling filters sampled per-op when interpolation is "random".
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs):
interpolation = kwargs.pop("resample", Image.NEAREST)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation
def _check_args_tf(kwargs):
    """Normalize transform kwargs in place for the current Pillow version.

    Drops 'fillcolor' on Pillow < 5.0 (unsupported there) and replaces
    'resample' with a single concrete filter.
    """
    if _PIL_VER < (5, 0) and "fillcolor" in kwargs:
        kwargs.pop("fillcolor")
    kwargs["resample"] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
    """Shear the image horizontally by *factor* via an affine transform."""
    _check_args_tf(kwargs)
    coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
def shear_y(img, factor, **kwargs):
    """Shear the image vertically by *factor* via an affine transform."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
def translate_x_rel(img, pct, **kwargs):
    """Translate the image horizontally by *pct* of its width."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, offset, 0, 1, 0), **kwargs)
def translate_y_rel(img, pct, **kwargs):
    """Translate the image vertically by *pct* of its height."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, offset), **kwargs)
def translate_x_abs(img, pixels, **kwargs):
    """Translate the image horizontally by an absolute number of pixels."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
def translate_y_abs(img, pixels, **kwargs):
    """Translate the image vertically by an absolute number of pixels."""
    _check_args_tf(kwargs)
    coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
def rotate(img, degrees, **kwargs):
    """Rotate the image by *degrees*; the implementation depends on the
    installed Pillow version."""
    _check_args_tf(kwargs)
    if _PIL_VER >= (5, 2):
        # Pillow >= 5.2: rotate() accepts the resample/fillcolor kwargs directly.
        return img.rotate(degrees, **kwargs)
    elif _PIL_VER >= (5, 0):
        # Pillow 5.0/5.1: emulate a center rotation with an explicit affine
        # matrix, since rotate() does not take the needed kwargs there.
        w, h = img.size
        post_trans = (0, 0)
        rotn_center = (w / 2.0, h / 2.0)
        # Negative angle: PIL's transform matrix maps output coords to input.
        angle = -(math.radians(degrees))
        matrix = [
            round(math.cos(angle), 15),
            round(math.sin(angle), 15),
            0.0,
            round(-(math.sin(angle)), 15),
            round(math.cos(angle), 15),
            0.0,
        ]

        def transform(x, y, matrix):
            # Apply the 2x3 affine matrix to the point (x, y).
            (a, b, c, d, e, f) = matrix
            return a * x + b * y + c, d * x + e * y + f

        # Shift the rotation center to the origin, rotate, then shift back,
        # folding the translation into the matrix's c/f entries.
        matrix[2], matrix[5] = transform(
            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
        )
        matrix[2] += rotn_center[0]
        matrix[5] += rotn_center[1]
        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
    else:
        # Pre-5.0 Pillow: only the resample argument can be forwarded.
        return img.rotate(degrees, resample=kwargs["resample"])
def auto_contrast(img, **_unused):
    """Maximize image contrast (the AutoAugment 'AutoContrast' op)."""
    return ImageOps.autocontrast(img)
def invert(img, **_unused):
    """Invert all pixel values (the AutoAugment 'Invert' op)."""
    return ImageOps.invert(img)
def equalize(img, **_unused):
    """Histogram-equalize the image (the AutoAugment 'Equalize' op)."""
    return ImageOps.equalize(img)
def solarize(img, thresh, **_unused):
    """Invert pixel values above *thresh* (the AutoAugment 'Solarize' op)."""
    return ImageOps.solarize(img, thresh)
def solarize_add(img, add, thresh=128, **_unused):
    """Solarize variant: add *add* (clipped to 255) to every pixel value
    below *thresh*, leaving brighter pixels untouched.

    Images that are neither 'L' nor 'RGB' are returned unchanged.
    """
    if img.mode not in ("L", "RGB"):
        return img
    # One 256-entry lookup table; shifted below the threshold, identity above.
    lut = [min(255, i + add) if i < thresh else i for i in range(256)]
    if img.mode == "RGB" and len(lut) == 256:
        # point() on RGB needs one table per channel.
        lut = lut * 3
    return img.point(lut)
def posterize(img, bits_to_keep, **_unused):
    """Reduce the image to *bits_to_keep* bits per channel.

    A value >= 8 is a no-op; values below 1 are clamped to 1 so the
    result is never an all-zero image.
    """
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, max(1, bits_to_keep))
def contrast(img, factor, **_unused):
    """Scale image contrast by *factor* (1.0 leaves the image unchanged)."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
def color(img, factor, **_unused):
    """Scale color saturation by *factor* (1.0 leaves the image unchanged)."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
def brightness(img, factor, **_unused):
    """Scale brightness by *factor* (1.0 leaves the image unchanged)."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
def sharpness(img, factor, **_unused):
    """Scale sharpness by *factor* (1.0 leaves the image unchanged)."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
def _randomly_negate(v):
"""With 50% prob, negate the value"""
return -v if random.random() > 0.5 else v
def _rotate_level_to_arg(level):
    """Map a 0..10 level to a rotation angle in [-30, 30] degrees."""
    degrees = _randomly_negate(level / _MAX_LEVEL * 30.0)
    return (degrees,)
def _enhance_level_to_arg(level):
    """Map a 0..10 level to an enhancement factor in [0.1, 1.9]."""
    factor = level / _MAX_LEVEL * 1.8 + 0.1
    return (factor,)
def _shear_level_to_arg(level):
    """Map a 0..10 level to a shear factor in [-0.3, 0.3] (sign randomized)."""
    shear = _randomly_negate(level / _MAX_LEVEL * 0.3)
    return (shear,)
def _translate_abs_level_to_arg(level, translate_const):
    """Map a 0..10 level to an absolute pixel offset in
    [-translate_const, translate_const] (sign randomized)."""
    offset = _randomly_negate(level / _MAX_LEVEL * float(translate_const))
    return (offset,)
def _translate_abs_level_to_arg2(level):
    """Like _translate_abs_level_to_arg, but hard-wired to the module-default
    translate_const (250).

    NOTE(review): this ignores any translate_const supplied via hparams
    (e.g. one scaled to the actual image size) — confirm that is intended.
    """
    scale = float(_HPARAMS_DEFAULT["translate_const"])
    offset = _randomly_negate(level / _MAX_LEVEL * scale)
    return (offset,)
def _translate_rel_level_to_arg(level):
    """Map a 0..10 level to a relative translation in [-0.45, 0.45]."""
    pct = _randomly_negate(level / _MAX_LEVEL * 0.45)
    return (pct,)
# Registry mapping AutoAugment op names (as used in the policy tables below)
# to their PIL-based implementations. "Posterize" and "Posterize2" share one
# implementation and differ only in their level-to-argument mapping.
NAME_TO_OP = {
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    "Rotate": rotate,
    "Posterize": posterize,
    "Posterize2": posterize,
    "Solarize": solarize,
    "SolarizeAdd": solarize_add,
    "Color": color,
    "Contrast": contrast,
    "Brightness": brightness,
    "Sharpness": sharpness,
    "ShearX": shear_x,
    "ShearY": shear_y,
    "TranslateX": translate_x_abs,
    "TranslateY": translate_y_abs,
    "TranslateXRel": translate_x_rel,
    "TranslateYRel": translate_y_rel,
}
def pass_fn(input):
    """Level mapper for ops that take no magnitude: yields no extra args."""
    return ()
def _conversion0(input):
    """Map a 0..10 level to a posterize bit count in [4, 8]."""
    bits = int(input / _MAX_LEVEL * 4) + 4
    return (bits,)
def _conversion1(input):
    """Map a 0..10 level to a posterize bit count in [0, 4] (inverted)."""
    bits = 4 - int(input / _MAX_LEVEL * 4)
    return (bits,)
def _conversion2(input):
    """Map a 0..10 level to a solarize threshold in [0, 256]."""
    thresh = int(input / _MAX_LEVEL * 256)
    return (thresh,)
def _conversion3(input):
    """Map a 0..10 level to a solarize-add amount in [0, 110]."""
    amount = int(input / _MAX_LEVEL * 110)
    return (amount,)
class AutoAugmentOp:
    """A single AutoAugment operation.

    Applies the PIL op registered under `name` with probability `prob`,
    at strength `magnitude` (0..10) jittered by a gaussian.

    Args:
        name: op name; must be a key of NAME_TO_OP.
        prob: probability of applying the op at all.
        magnitude: base strength on the 0.._MAX_LEVEL scale.
        hparams: optional dict with "img_mean", "interpolation", and
            "translate_const" overrides.
    """

    def __init__(self, name, prob, magnitude, hparams=None):
        hparams = hparams or {}
        self.aug_fn = NAME_TO_OP[name]
        # Select the level -> op-arguments mapping for this op.
        if name in ("AutoContrast", "Equalize", "Invert"):
            self.level_fn = pass_fn
        elif name == "Rotate":
            self.level_fn = _rotate_level_to_arg
        elif name == "Posterize":
            self.level_fn = _conversion0
        elif name == "Posterize2":
            self.level_fn = _conversion1
        elif name == "Solarize":
            self.level_fn = _conversion2
        elif name == "SolarizeAdd":
            self.level_fn = _conversion3
        elif name in ("Color", "Contrast", "Brightness", "Sharpness"):
            self.level_fn = _enhance_level_to_arg
        elif name in ("ShearX", "ShearY"):
            self.level_fn = _shear_level_to_arg
        elif name in ("TranslateX", "TranslateY"):
            # BUGFIX: honor a translate_const supplied via hparams (e.g. the
            # image-size-scaled value built by VideoAutoAugmentTF) instead of
            # always using the module default of 250.
            translate_const = hparams.get(
                "translate_const", _HPARAMS_DEFAULT["translate_const"]
            )
            self.level_fn = lambda level, _tc=translate_const: (
                _translate_abs_level_to_arg(level, _tc)
            )
        elif name in ("TranslateXRel", "TranslateYRel"):
            self.level_fn = _translate_rel_level_to_arg
        else:
            # BUGFIX: original code printed "{} not recognized".format({}),
            # i.e. an empty dict instead of the unrecognized op name.
            print("{} not recognized".format(name))
        self.prob = prob
        self.magnitude = magnitude
        # If std deviation of magnitude is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from normal dist
        # with mean magnitude and std-dev of magnitude_std.
        # NOTE This is being tested as it's not in paper or reference impl.
        self.magnitude_std = 0.5  # FIXME add arg/hparam
        self.kwargs = {
            "fillcolor": hparams.get("img_mean", _FILL),
            "resample": hparams.get("interpolation", _RANDOM_INTERPOLATION),
        }

    def __call__(self, img):
        """Apply the op to a PIL image with probability self.prob."""
        if self.prob < random.random():
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        # Clamp the jittered magnitude back into the valid level range.
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        return self.aug_fn(img, *level_args, **self.kwargs)
class AutoAugmentVideoOp(AutoAugmentOp):
    """Video variant of AutoAugmentOp: samples the op parameters once per
    clip, then applies the identical op to every frame."""

    def __call__(self, vid):
        """
        vid (numpy, 255): thwc
        output (numpy, 255): thwc
        """
        if self.prob < random.random():
            return vid
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        # Round-trip each HWC frame through PIL to apply the op.
        augmented = [
            np.array(self.aug_fn(Image.fromarray(frame), *level_args, **self.kwargs))
            for frame in vid
        ]
        return np.stack(augmented, axis=0)
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
    """Build the TF "v0" policy as a list of sub-policies, each a pair of
    AutoAugmentVideoOp built from (name, probability, magnitude) tuples."""
    # ImageNet policy from TPU EfficientNet impl, cannot find
    # a paper reference.
    policy = [
        [("Equalize", 0.8, 1), ("ShearY", 0.8, 4)],
        [("Color", 0.4, 9), ("Equalize", 0.6, 3)],
        [("Color", 0.4, 1), ("Rotate", 0.6, 8)],
        [("Solarize", 0.8, 3), ("Equalize", 0.4, 7)],
        [("Solarize", 0.4, 2), ("Solarize", 0.6, 2)],
        [("Color", 0.2, 0), ("Equalize", 0.8, 8)],
        [("Equalize", 0.4, 8), ("SolarizeAdd", 0.8, 3)],
        [("ShearX", 0.2, 9), ("Rotate", 0.6, 8)],
        [("Color", 0.6, 1), ("Equalize", 1.0, 2)],
        [("Invert", 0.4, 9), ("Rotate", 0.6, 0)],
        [("Equalize", 1.0, 9), ("ShearY", 0.6, 3)],
        [("Color", 0.4, 7), ("Equalize", 0.6, 0)],
        [("Posterize", 0.4, 6), ("AutoContrast", 0.4, 7)],
        [("Solarize", 0.6, 8), ("Color", 0.6, 9)],
        [("Solarize", 0.2, 4), ("Rotate", 0.8, 9)],
        [("Rotate", 1.0, 7), ("TranslateYRel", 0.8, 9)],
        [("ShearX", 0.0, 0), ("Solarize", 0.8, 4)],
        [("ShearY", 0.8, 0), ("Color", 0.6, 4)],
        [("Color", 1.0, 0), ("Rotate", 0.6, 2)],
        [("Equalize", 0.8, 4), ("Equalize", 0.0, 8)],
        [("Equalize", 1.0, 4), ("AutoContrast", 0.6, 2)],
        [("ShearY", 0.4, 7), ("SolarizeAdd", 0.6, 7)],
        [("Posterize", 0.8, 2), ("Solarize", 0.6, 10)],
        [("Solarize", 0.6, 8), ("Equalize", 0.6, 1)],
        [("Color", 0.8, 6), ("Rotate", 0.4, 5)],
    ]
    # Instantiate each (name, prob, magnitude) triple as a video op.
    pc = [[AutoAugmentVideoOp(*a, hparams) for a in sp] for sp in policy]
    return pc
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
    """Build the original-paper policy as a list of sub-policies, each a
    pair of AutoAugmentVideoOp built from (name, prob, magnitude) tuples."""
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    policy = [
        [("Posterize", 0.4, 8), ("Rotate", 0.6, 9)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
        [("Posterize", 0.6, 7), ("Posterize", 0.6, 6)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Equalize", 0.4, 4), ("Rotate", 0.8, 8)],
        [("Solarize", 0.6, 3), ("Equalize", 0.6, 7)],
        [("Posterize", 0.8, 5), ("Equalize", 1.0, 2)],
        [("Rotate", 0.2, 3), ("Solarize", 0.6, 8)],
        [("Equalize", 0.6, 8), ("Posterize", 0.4, 6)],
        [("Rotate", 0.8, 8), ("Color", 0.4, 0)],
        [("Rotate", 0.4, 9), ("Equalize", 0.6, 2)],
        [("Equalize", 0.0, 7), ("Equalize", 0.8, 8)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Rotate", 0.8, 8), ("Color", 1.0, 2)],
        [("Color", 0.8, 8), ("Solarize", 0.8, 7)],
        [("Sharpness", 0.4, 7), ("Invert", 0.6, 8)],
        [("ShearX", 0.6, 5), ("Equalize", 1.0, 9)],
        [("Color", 0.4, 0), ("Equalize", 0.6, 3)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
    ]
    # Instantiate each (name, prob, magnitude) triple as a video op.
    pc = [[AutoAugmentVideoOp(*a, hparams) for a in sp] for sp in policy]
    return pc
def auto_augment_policy(name="v0", hparams=_HPARAMS_DEFAULT):
    """Build the sub-policy list for the given policy name.

    Supported names are "v0" and "original"; anything else logs a
    message and raises AssertionError.
    """
    if name == "v0":
        return auto_augment_policy_v0(hparams)
    if name == "original":
        return auto_augment_policy_original(hparams)
    print("Unknown auto_augmentation policy {}".format(name))
    raise AssertionError()
def to_tensor_thwc_uint8(clip: torch.Tensor):
    """
    Inverse of torchvision.transforms._functional_video.to_tensor():
    convert a video clip from CTHW float32 (0~1.0) to THWC uint8 (0~255).

    Args:
        clip (torch.Tensor, dtype=torch.float32): Size is (C, T, H, W)
    Returns:
        torch.Tensor, dtype=torch.uint8: Size is (T, H, W, C)
    """
    assert isinstance(clip, torch.Tensor)
    assert clip.dim() == 4
    assert clip.dtype == torch.float32
    scaled = clip.mul(255.0)
    return scaled.permute(1, 2, 3, 0).byte()
@register_transform("ToTensorTHWCUInt8")
class ToTensorTHWCUInt8(ClassyTransform):
    # Classy-vision wrapper around to_tensor_thwc_uint8.
    def __call__(self, clip: torch.Tensor):
        """Convert a CTHW float32 (0~1) clip into a THWC uint8 (0~255) clip."""
        return to_tensor_thwc_uint8(clip)
@register_transform("VideoAutoAugmentTF")
class VideoAutoAugmentTF(ClassyTransform):
    """Classy transform applying TF-style AutoAugment to a video clip.

    A random sub-policy (pair of ops) is drawn per clip and applied
    frame-by-frame.
    """

    def __init__(
        self,
        policy_name,
        img_size,
        pixel_mean,
        interpolation,
    ):
        """
        Args:
            policy_name: "v0" or "original".
            img_size: int, or (h, w) tuple/list; the minimum side scales the
                absolute-translation range.
            pixel_mean: list of three values of 0~1
            interpolation: PIL filter name, or "random" to sample a filter
                per op.
        """
        # Generalized: accept list as well as tuple for img_size
        # (previously only a tuple was unpacked; a list fell through to the
        # scalar branch and min() was never taken).
        if isinstance(img_size, (list, tuple)):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        # Scale the 0~1 per-channel mean to 0~255 for use as the fill color.
        scaled_mean = [x * 255.0 for x in pixel_mean]
        aa_params = {
            "translate_const": int(img_size_min * 0.45),
            "img_mean": tuple(round(x) for x in scaled_mean),
        }
        if interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        self.policy = auto_augment_policy(policy_name, aa_params)

    def __call__(self, vid):
        """Augment a THWC clip; returns a uint8 tensor in range 0~255."""
        sub_policy = random.choice(self.policy)
        assert isinstance(vid, torch.Tensor), "Wrong type of input"
        if vid.dtype != torch.uint8:
            # NOTE(review): assumes non-uint8 input lies in [0, 1]; values
            # above 1 would wrap during the uint8 cast — confirm upstream.
            vid = vid * 255.0
        vid = vid.numpy().astype(np.uint8)
        for op in sub_policy:
            vid = op(vid)
        # Return to Tensor uint8, range of 0~255
        vid = torch.from_numpy(vid).type(torch.uint8)
        return vid
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment