Created
February 21, 2021 20:42
-
-
Save stephenyan1231/91091040c726c3fa630e8a3e22128f21 to your computer and use it in GitHub Desktop.
deeplearning/projects/classy_vision/fb/dataset/transforms/autoaug_video.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| # (c) Facebook, Inc. and its affiliates. Confidential and proprietary. | |
| # Copied from D25942231, D22269078 but MODIFIED FOR VIDEO, | |
| # and referred D24414029 | |
| """ Auto Augment | |
| Implementation adapted from: | |
| https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py | |
| Papers: https://arxiv.org/abs/1805.09501 and https://arxiv.org/abs/1906.11172 | |
| Note that the TF auto-augmentation is different from the PyTorch auto-augmentation | |
| TF's implementation is also different from the original paper | |
| Hacked together by Ross Wightman (https://github.com/rwightman) and MODIFIED FOR VIDEO | |
| """ | |
| import math | |
| import random | |
| import numpy as np | |
| import PIL | |
| import torch | |
| from classy_vision.dataset.transforms import ClassyTransform, register_transform | |
| from PIL import Image, ImageEnhance, ImageOps | |
# NOTE: ``_RANDOM_INTERPOLATION`` is assigned a second time further down in
# this module; that later tuple (which also contains NEAREST) is the value in
# effect once the module has finished importing.
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)

# Printable names for the PIL resampling filters, keyed by the filter constant.
_pil_interpolation_to_str = {
    Image.NEAREST: "PIL.Image.NEAREST",
    Image.BILINEAR: "PIL.Image.BILINEAR",
    Image.BICUBIC: "PIL.Image.BICUBIC",
    Image.LANCZOS: "PIL.Image.LANCZOS",
    Image.HAMMING: "PIL.Image.HAMMING",
    Image.BOX: "PIL.Image.BOX",
}
def _pil_interp(method):
    """Translate an interpolation name into the matching PIL filter constant.

    Recognised names are ``"bicubic"``, ``"lanczos"`` and ``"hamming"``; any
    other value (including ``"bilinear"``) yields ``Image.BILINEAR``.
    """
    named_filters = (
        ("bicubic", Image.BICUBIC),
        ("lanczos", Image.LANCZOS),
        ("hamming", Image.HAMMING),
    )
    for filter_name, pil_filter in named_filters:
        if method == filter_name:
            return pil_filter
    # Bilinear is the fallback; nearest is deliberately not selectable here.
    return Image.BILINEAR
# (major, minor) of the installed Pillow; used to gate version-specific kwargs.
_PIL_VER = tuple(map(int, PIL.__version__.split(".")[:2]))

# Default fill colour handed to the geometric ops as ``fillcolor``.
_FILL = (128, 128, 128)

# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
_MAX_LEVEL = 10.0

_HPARAMS_DEFAULT = {"translate_const": 250, "img_mean": _FILL}

# Candidate resampling filters; one is drawn at random whenever an op is given
# a list/tuple as its ``resample`` argument.
_RANDOM_INTERPOLATION = (Image.NEAREST, Image.BILINEAR, Image.BICUBIC)
def _interpolation(kwargs):
    """Pop ``"resample"`` out of ``kwargs`` and resolve it to a single filter.

    A list or tuple is treated as a set of candidates and one is picked at
    random; any other value is returned as is.  ``Image.NEAREST`` is used when
    the key is absent.  Note that ``kwargs`` is modified in place.
    """
    choice = kwargs.pop("resample", Image.NEAREST)
    return random.choice(choice) if isinstance(choice, (list, tuple)) else choice
def _check_args_tf(kwargs):
    """Normalise, in place, the keyword arguments forwarded to a PIL call.

    ``fillcolor`` is removed on Pillow older than 5.0, and ``resample`` is
    replaced by a single concrete filter (see ``_interpolation``).
    """
    if _PIL_VER < (5, 0):
        # Same effect as "remove it if present".
        kwargs.pop("fillcolor", None)
    kwargs["resample"] = _interpolation(kwargs)
def shear_x(img, factor, **kwargs):
    """Apply an affine transform to ``img`` with x-shear coefficient ``factor``.

    Extra keyword arguments (``resample``, ``fillcolor``) are normalised and
    forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    affine_coeffs = (1, factor, 0, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def shear_y(img, factor, **kwargs):
    """Apply an affine transform to ``img`` with y-shear coefficient ``factor``.

    Extra keyword arguments (``resample``, ``fillcolor``) are normalised and
    forwarded to ``img.transform``.
    """
    _check_args_tf(kwargs)
    affine_coeffs = (1, 0, 0, factor, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def translate_x_rel(img, pct, **kwargs):
    """Translate ``img`` along x by ``pct`` times its width (``img.size[0]``)."""
    offset = pct * img.size[0]
    _check_args_tf(kwargs)
    affine_coeffs = (1, 0, offset, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def translate_y_rel(img, pct, **kwargs):
    """Translate ``img`` along y by ``pct`` times its height (``img.size[1]``)."""
    offset = pct * img.size[1]
    _check_args_tf(kwargs)
    affine_coeffs = (1, 0, 0, 0, 1, offset)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def translate_x_abs(img, pixels, **kwargs):
    """Translate ``img`` along x using ``pixels`` as the affine x-offset."""
    _check_args_tf(kwargs)
    affine_coeffs = (1, 0, pixels, 0, 1, 0)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def translate_y_abs(img, pixels, **kwargs):
    """Translate ``img`` along y using ``pixels`` as the affine y-offset."""
    _check_args_tf(kwargs)
    affine_coeffs = (1, 0, 0, 0, 1, pixels)
    return img.transform(img.size, Image.AFFINE, affine_coeffs, **kwargs)
def rotate(img, degrees, **kwargs):
    """Rotate ``img`` by ``degrees``, adapting to the installed Pillow version.

    * Pillow >= 5.2: ``img.rotate`` receives all normalised kwargs.
    * Pillow 5.0 / 5.1: a rotation about the image centre is built by hand as
      an affine matrix and applied through ``img.transform``.
    * Older Pillow: ``img.rotate`` is called with ``resample`` only.
    """
    _check_args_tf(kwargs)

    if _PIL_VER >= (5, 2):
        return img.rotate(degrees, **kwargs)
    if _PIL_VER < (5, 0):
        return img.rotate(degrees, resample=kwargs["resample"])

    # Pillow 5.0 / 5.1: assemble the affine coefficients manually.
    width, height = img.size
    shift_x, shift_y = 0, 0
    center_x, center_y = width / 2.0, height / 2.0
    theta = -(math.radians(degrees))
    cos_t = round(math.cos(theta), 15)
    coeffs = [
        cos_t,
        round(math.sin(theta), 15),
        0.0,
        round(-(math.sin(theta)), 15),
        cos_t,
        0.0,
    ]

    def _apply(x, y, m):
        # Map the point (x, y) through the 2x3 affine matrix ``m``.
        (a, b, c, d, e, f) = m
        return a * x + b * y + c, d * x + e * y + f

    # Move the centre to the origin, rotate, then move it back.
    coeffs[2], coeffs[5] = _apply(-center_x - shift_x, -center_y - shift_y, coeffs)
    coeffs[2] += center_x
    coeffs[5] += center_y
    return img.transform(img.size, Image.AFFINE, coeffs, **kwargs)
def auto_contrast(img, **__):
    """Return ``ImageOps.autocontrast(img)``; extra keyword arguments are ignored."""
    result = ImageOps.autocontrast(img)
    return result
def invert(img, **__):
    """Return ``ImageOps.invert(img)``; extra keyword arguments are ignored."""
    result = ImageOps.invert(img)
    return result
def equalize(img, **__):
    """Return ``ImageOps.equalize(img)``; extra keyword arguments are ignored."""
    result = ImageOps.equalize(img)
    return result
def solarize(img, thresh, **__):
    """Return ``ImageOps.solarize(img, thresh)``; extra keyword arguments are ignored."""
    result = ImageOps.solarize(img, thresh)
    return result
def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` to every pixel value below ``thresh``, saturating at 255.

    Only images in mode ``"L"`` or ``"RGB"`` are processed; any other mode is
    returned untouched.  Extra keyword arguments are ignored.
    """
    # 256-entry lookup table: values under the threshold are brightened.
    table = [min(255, value + add) if value < thresh else value for value in range(256)]
    if img.mode not in ("L", "RGB"):
        return img
    if img.mode == "RGB":
        # One copy of the table per channel.
        table = table * 3
    return img.point(table)
def posterize(img, bits_to_keep, **__):
    """Posterize ``img`` down to ``bits_to_keep`` bits.

    A request for 8 or more bits returns ``img`` unchanged, and the bit count
    is floored at 1 so the result is never an all-zero image.
    """
    if bits_to_keep >= 8:
        return img
    return ImageOps.posterize(img, max(1, bits_to_keep))  # prevent all 0 images
def contrast(img, factor, **__):
    """Run ``ImageEnhance.Contrast`` on ``img`` with the given ``factor``."""
    enhancer = ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
def color(img, factor, **__):
    """Run ``ImageEnhance.Color`` on ``img`` with the given ``factor``."""
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
def brightness(img, factor, **__):
    """Run ``ImageEnhance.Brightness`` on ``img`` with the given ``factor``."""
    enhancer = ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
def sharpness(img, factor, **__):
    """Run ``ImageEnhance.Sharpness`` on ``img`` with the given ``factor``."""
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
def _randomly_negate(v):
    """Return ``-v`` when a fresh ``random.random()`` draw exceeds 0.5, else ``v``."""
    if random.random() > 0.5:
        return -v
    return v
def _rotate_level_to_arg(level):
    """Scale ``level`` to a rotation angle in [-30, 30] degrees with a random sign.

    The stated range holds for ``level`` in [0, _MAX_LEVEL].  Returns a
    1-tuple so the result can be splatted into the op call.
    """
    degrees = _randomly_negate((level / _MAX_LEVEL) * 30.0)
    return (degrees,)
def _enhance_level_to_arg(level):
    """Scale ``level`` to an enhancement factor in [0.1, 1.9].

    The stated range holds for ``level`` in [0, _MAX_LEVEL].  Returns a
    1-tuple so the result can be splatted into the op call.
    """
    factor = (level / _MAX_LEVEL) * 1.8 + 0.1
    return (factor,)
def _shear_level_to_arg(level):
    """Scale ``level`` to a shear factor in [-0.3, 0.3] with a random sign.

    The stated range holds for ``level`` in [0, _MAX_LEVEL].  Returns a
    1-tuple so the result can be splatted into the op call.
    """
    shear = _randomly_negate((level / _MAX_LEVEL) * 0.3)
    return (shear,)
def _translate_abs_level_to_arg(level, translate_const):
    """Scale ``level`` to an offset of at most ``translate_const``, random sign.

    Returns a 1-tuple so the result can be splatted into the op call.
    """
    max_offset = float(translate_const)
    offset = _randomly_negate((level / _MAX_LEVEL) * max_offset)
    return (offset,)
def _translate_abs_level_to_arg2(level):
    """Single-argument form of the absolute-translation level mapping.

    The maximum offset is always ``_HPARAMS_DEFAULT["translate_const"]``; the
    sign of the result is chosen at random.  Returns a 1-tuple.
    """
    max_offset = float(_HPARAMS_DEFAULT["translate_const"])
    offset = _randomly_negate((level / _MAX_LEVEL) * max_offset)
    return (offset,)
def _translate_rel_level_to_arg(level):
    """Scale ``level`` to a relative offset in [-0.45, 0.45] with a random sign.

    The stated range holds for ``level`` in [0, _MAX_LEVEL].  Returns a
    1-tuple so the result can be splatted into the op call.
    """
    fraction = _randomly_negate((level / _MAX_LEVEL) * 0.45)
    return (fraction,)
# Registry mapping a policy op name to the function that applies it.
NAME_TO_OP = {
    # Ops that take no magnitude argument.
    "AutoContrast": auto_contrast,
    "Equalize": equalize,
    "Invert": invert,
    # Geometric / tonal ops driven by one magnitude-derived argument.
    "Rotate": rotate,
    "Posterize": posterize,
    "Posterize2": posterize,  # same op as "Posterize"; differs only in level mapping
    "Solarize": solarize,
    "SolarizeAdd": solarize_add,
    # ImageEnhance-based ops.
    "Color": color,
    "Contrast": contrast,
    "Brightness": brightness,
    "Sharpness": sharpness,
    # Affine ops.
    "ShearX": shear_x,
    "ShearY": shear_y,
    "TranslateX": translate_x_abs,
    "TranslateY": translate_y_abs,
    "TranslateXRel": translate_x_rel,
    "TranslateYRel": translate_y_rel,
}
def pass_fn(input):
    """Level mapping for ops that take no magnitude: always an empty tuple."""
    return tuple()
def _conversion0(input):
    """Map a level to ``int(level / _MAX_LEVEL * 4) + 4``, returned as a 1-tuple.

    Used as the level mapping for the "Posterize" op.
    """
    bits = int((input / _MAX_LEVEL) * 4) + 4
    return (bits,)
def _conversion1(input):
    """Map a level to ``4 - int(level / _MAX_LEVEL * 4)``, returned as a 1-tuple.

    Used as the level mapping for the "Posterize2" op.
    """
    bits = 4 - int((input / _MAX_LEVEL) * 4)
    return (bits,)
def _conversion2(input):
    """Map a level to ``int(level / _MAX_LEVEL * 256)``, returned as a 1-tuple.

    Used as the level mapping for the "Solarize" op.
    """
    threshold = int((input / _MAX_LEVEL) * 256)
    return (threshold,)
def _conversion3(input):
    """Map a level to ``int(level / _MAX_LEVEL * 110)``, returned as a 1-tuple.

    Used as the level mapping for the "SolarizeAdd" op.
    """
    addition = int((input / _MAX_LEVEL) * 110)
    return (addition,)
class AutoAugmentOp:
    """One augmentation op of a sub-policy, applied to a single PIL image.

    Args:
        name: key into ``NAME_TO_OP`` selecting the op.
        prob: probability with which the op is applied on each call.
        magnitude: mean magnitude, clamped to [0, _MAX_LEVEL] after sampling.
        hparams: optional dict; the keys read are ``"translate_const"`` (max
            offset for TranslateX / TranslateY), ``"img_mean"`` (fill colour)
            and ``"interpolation"`` (resample filter or list of candidates).

    Raises:
        KeyError: if ``name`` is not present in ``NAME_TO_OP``.
        ValueError: if ``name`` has an op but no level mapping.
    """

    def __init__(self, name, prob, magnitude, hparams=None):
        # Function-scope import so this block does not depend on the module's
        # import section.
        from functools import partial

        hparams = hparams or {}
        self.aug_fn = NAME_TO_OP[name]
        if name in ("AutoContrast", "Equalize", "Invert"):
            self.level_fn = pass_fn
        elif name == "Rotate":
            self.level_fn = _rotate_level_to_arg
        elif name == "Posterize":
            self.level_fn = _conversion0
        elif name == "Posterize2":
            self.level_fn = _conversion1
        elif name == "Solarize":
            self.level_fn = _conversion2
        elif name == "SolarizeAdd":
            self.level_fn = _conversion3
        elif name in ("Color", "Contrast", "Brightness", "Sharpness"):
            self.level_fn = _enhance_level_to_arg
        elif name in ("ShearX", "ShearY"):
            self.level_fn = _shear_level_to_arg
        elif name in ("TranslateX", "TranslateY"):
            # Honour the caller-supplied translate_const instead of always
            # using the module default.  A partial over a module-level
            # function (rather than a lambda) keeps the op picklable.
            self.level_fn = partial(
                _translate_abs_level_to_arg,
                translate_const=hparams.get(
                    "translate_const", _HPARAMS_DEFAULT["translate_const"]
                ),
            )
        elif name in ("TranslateXRel", "TranslateYRel"):
            self.level_fn = _translate_rel_level_to_arg
        else:
            # Only reachable if NAME_TO_OP gains an op without a level mapping;
            # fail here rather than later with a missing ``level_fn``.
            raise ValueError("{} not recognized".format(name))
        self.prob = prob
        self.magnitude = magnitude
        # If std deviation of magnitude is > 0, we introduce some randomness
        # in the usually fixed policy and sample magnitude from normal dist
        # with mean magnitude and std-dev of magnitude_std.
        # NOTE This is being tested as it's not in paper or reference impl.
        self.magnitude_std = 0.5  # FIXME add arg/hparam
        # Keyword arguments forwarded to every op call.
        self.kwargs = {
            "fillcolor": hparams.get("img_mean", _FILL),
            "resample": hparams.get("interpolation", _RANDOM_INTERPOLATION),
        }

    def __call__(self, img):
        """Apply the op to ``img`` with probability ``prob``; else return ``img``."""
        if self.prob < random.random():
            return img
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)
        return self.aug_fn(img, *level_args, **self.kwargs)
class AutoAugmentVideoOp(AutoAugmentOp):
    """Video variant of ``AutoAugmentOp``: one set of parameters for all frames."""

    def __call__(self, vid):
        """
        vid (numpy, 255): thwc
        output (numpy, 255): thwc

        With probability ``1 - prob`` the clip is returned unchanged.
        Otherwise the magnitude, the level arguments and the resample filter
        are each drawn once and applied identically to every frame.
        """
        if self.prob < random.random():
            return vid
        magnitude = self.magnitude
        if self.magnitude_std and self.magnitude_std > 0:
            magnitude = random.gauss(magnitude, self.magnitude_std)
        magnitude = min(_MAX_LEVEL, max(0, magnitude))
        level_args = self.level_fn(magnitude)

        # Resolve a list/tuple of candidate filters to ONE filter per clip.
        # Passing the candidates through would let each frame draw its own
        # filter, giving inconsistent warps across the frames of one clip.
        clip_kwargs = dict(self.kwargs)
        resample = clip_kwargs.get("resample")
        if isinstance(resample, (list, tuple)):
            clip_kwargs["resample"] = random.choice(resample)

        frames_hwc = [
            np.array(self.aug_fn(Image.fromarray(frame_hwc), *level_args, **clip_kwargs))
            for frame_hwc in vid  # iterate over the leading (frame) axis
        ]
        return np.stack(frames_hwc, axis=0)
def auto_augment_policy_v0(hparams=_HPARAMS_DEFAULT):
    """Build the "v0" policy as a list of sub-policies of ``AutoAugmentVideoOp``.

    Each entry below is ``(op name, probability, magnitude)``; ``hparams`` is
    handed to every op.
    """
    # ImageNet policy from TPU EfficientNet impl, cannot find
    # a paper reference.
    policy = [
        [("Equalize", 0.8, 1), ("ShearY", 0.8, 4)],
        [("Color", 0.4, 9), ("Equalize", 0.6, 3)],
        [("Color", 0.4, 1), ("Rotate", 0.6, 8)],
        [("Solarize", 0.8, 3), ("Equalize", 0.4, 7)],
        [("Solarize", 0.4, 2), ("Solarize", 0.6, 2)],
        [("Color", 0.2, 0), ("Equalize", 0.8, 8)],
        [("Equalize", 0.4, 8), ("SolarizeAdd", 0.8, 3)],
        [("ShearX", 0.2, 9), ("Rotate", 0.6, 8)],
        [("Color", 0.6, 1), ("Equalize", 1.0, 2)],
        [("Invert", 0.4, 9), ("Rotate", 0.6, 0)],
        [("Equalize", 1.0, 9), ("ShearY", 0.6, 3)],
        [("Color", 0.4, 7), ("Equalize", 0.6, 0)],
        [("Posterize", 0.4, 6), ("AutoContrast", 0.4, 7)],
        [("Solarize", 0.6, 8), ("Color", 0.6, 9)],
        [("Solarize", 0.2, 4), ("Rotate", 0.8, 9)],
        [("Rotate", 1.0, 7), ("TranslateYRel", 0.8, 9)],
        [("ShearX", 0.0, 0), ("Solarize", 0.8, 4)],
        [("ShearY", 0.8, 0), ("Color", 0.6, 4)],
        [("Color", 1.0, 0), ("Rotate", 0.6, 2)],
        [("Equalize", 0.8, 4), ("Equalize", 0.0, 8)],
        [("Equalize", 1.0, 4), ("AutoContrast", 0.6, 2)],
        [("ShearY", 0.4, 7), ("SolarizeAdd", 0.6, 7)],
        [("Posterize", 0.8, 2), ("Solarize", 0.6, 10)],
        [("Solarize", 0.6, 8), ("Equalize", 0.6, 1)],
        [("Color", 0.8, 6), ("Rotate", 0.4, 5)],
    ]
    return [
        [
            AutoAugmentVideoOp(op_name, prob, magnitude, hparams)
            for op_name, prob, magnitude in sub_policy
        ]
        for sub_policy in policy
    ]
def auto_augment_policy_original(hparams=_HPARAMS_DEFAULT):
    """Build the "original" policy as a list of sub-policies of ``AutoAugmentVideoOp``.

    Each entry below is ``(op name, probability, magnitude)``; ``hparams`` is
    handed to every op.
    """
    # ImageNet policy from https://arxiv.org/abs/1805.09501
    policy = [
        [("Posterize", 0.4, 8), ("Rotate", 0.6, 9)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
        [("Posterize", 0.6, 7), ("Posterize", 0.6, 6)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Equalize", 0.4, 4), ("Rotate", 0.8, 8)],
        [("Solarize", 0.6, 3), ("Equalize", 0.6, 7)],
        [("Posterize", 0.8, 5), ("Equalize", 1.0, 2)],
        [("Rotate", 0.2, 3), ("Solarize", 0.6, 8)],
        [("Equalize", 0.6, 8), ("Posterize", 0.4, 6)],
        [("Rotate", 0.8, 8), ("Color", 0.4, 0)],
        [("Rotate", 0.4, 9), ("Equalize", 0.6, 2)],
        [("Equalize", 0.0, 7), ("Equalize", 0.8, 8)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Rotate", 0.8, 8), ("Color", 1.0, 2)],
        [("Color", 0.8, 8), ("Solarize", 0.8, 7)],
        [("Sharpness", 0.4, 7), ("Invert", 0.6, 8)],
        [("ShearX", 0.6, 5), ("Equalize", 1.0, 9)],
        [("Color", 0.4, 0), ("Equalize", 0.6, 3)],
        [("Equalize", 0.4, 7), ("Solarize", 0.2, 4)],
        [("Solarize", 0.6, 5), ("AutoContrast", 0.6, 5)],
        [("Invert", 0.6, 4), ("Equalize", 1.0, 8)],
        [("Color", 0.6, 4), ("Contrast", 1.0, 8)],
        [("Equalize", 0.8, 8), ("Equalize", 0.6, 3)],
    ]
    return [
        [
            AutoAugmentVideoOp(op_name, prob, magnitude, hparams)
            for op_name, prob, magnitude in sub_policy
        ]
        for sub_policy in policy
    ]
def auto_augment_policy(name="v0", hparams=_HPARAMS_DEFAULT):
    """Return the augmentation policy registered under ``name``.

    Args:
        name: ``"original"`` or ``"v0"``.
        hparams: dict forwarded to the policy builder.

    Returns:
        A list of sub-policies, each a list of ``AutoAugmentVideoOp``.

    Raises:
        AssertionError: if ``name`` is not a known policy.
    """
    if name == "original":
        return auto_augment_policy_original(hparams)
    if name == "v0":
        return auto_augment_policy_v0(hparams)
    # Same exception type as before, but the message now travels with the
    # exception instead of only being printed to stdout.
    raise AssertionError("Unknown auto_augmentation policy {}".format(name))
def to_tensor_thwc_uint8(clip: torch.Tensor):
    """
    Inverse of torchvision.transforms._functional_video.to_tensor()
    Convert video clip in CTHW float32 (0~1.0) format to THWC uint8 (0~255)
    Values outside 0~1.0 are saturated to 0 / 255 rather than wrapping around
    in the uint8 cast. In-range values are truncated exactly as before.
    Args:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    Return:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Raises:
        AssertionError: if ``clip`` is not a 4-D float32 tensor.
    """
    assert isinstance(clip, torch.Tensor), "clip must be a torch.Tensor"
    assert len(clip.shape) == 4, "clip must be 4-D (C, T, H, W)"
    assert clip.dtype == torch.float32, "clip must have dtype torch.float32"
    # Clamp before the cast so out-of-range floats cannot overflow uint8.
    scaled = (clip * 255.0).clamp(0.0, 255.0)
    return scaled.permute(1, 2, 3, 0).byte()
@register_transform("ToTensorTHWCUInt8")
class ToTensorTHWCUInt8(ClassyTransform):
    """Transform wrapper around ``to_tensor_thwc_uint8``."""

    def __call__(self, clip: torch.Tensor):
        """Convert a (C, T, H, W) float32 clip to a (T, H, W, C) uint8 clip."""
        converted = to_tensor_thwc_uint8(clip)
        return converted
@register_transform("VideoAutoAugmentTF")
class VideoAutoAugmentTF(ClassyTransform):
    """Apply one randomly chosen AutoAugment sub-policy to a video clip."""

    def __init__(
        self,
        policy_name,
        img_size,
        pixel_mean,
        interpolation,
    ):
        """
        policy_name: policy to build, passed to ``auto_augment_policy``
        img_size: a single number, or a list/tuple of sizes whose minimum is used
        pixel_mean: list of three values of 0~1
        interpolation: "random" to leave the op default (random filter choice),
            otherwise a name understood by ``_pil_interp``
        """
        # Accept lists as well as tuples: sizes that come from JSON configs
        # arrive as lists and would otherwise be used as a number below.
        if isinstance(img_size, (list, tuple)):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = {
            "translate_const": int(img_size_min * 0.45),
            # Fill colour on the 0~255 scale.
            "img_mean": tuple(round(x * 255.0) for x in pixel_mean),
        }
        if interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        self.policy = auto_augment_policy(policy_name, aa_params)

    def __call__(self, vid):
        """Augment ``vid`` and return it as a uint8 tensor in the range 0~255.

        ``vid`` must be a ``torch.Tensor``.  A tensor that is not already
        uint8 is multiplied by 255 before being cast to uint8.  The ops
        iterate over the first axis as frames (their docstring states a thwc
        layout).
        """
        sub_policy = random.choice(self.policy)
        assert isinstance(vid, torch.Tensor), "Wrong type of input"
        if vid.dtype != torch.uint8:
            vid = vid * 255.0
        vid = vid.numpy().astype(np.uint8)
        for op in sub_policy:
            vid = op(vid)
        # Return to Tensor uint8, range of 0~255
        vid = torch.from_numpy(vid).type(torch.uint8)
        return vid
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment