Created
June 25, 2021 14:39
-
-
Save theeluwin/c388e7edd547e8da84db640ec568b51f to your computer and use it in GitHub Desktop.
from @ihl7029
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import PIL | |
import PIL.ImageOps | |
import PIL.ImageEnhance | |
import PIL.ImageDraw | |
# Public names exported by a star-import of this module.
__all__ = (
    # Range table and its validator.
    'augment_bound',
    'check_augment_min_max',
    # Ops that ignore their magnitude argument.
    'identity',
    'autocontrast',
    'equalize',
    'invert',
    # Ops driven by an integer-rounded magnitude.
    'posterize',
    'solarize',
    # PIL.ImageEnhance based ops.
    'color',
    'contrast',
    'brightness',
    'sharpness',
    # Geometric ops.
    'rotate',
    'translateX',
    'translateY',
    'shearX',
    'shearY',
    # Policy class that samples and applies the ops above.
    'RandAug',
)
# Widest (lower, upper) magnitude range accepted for each op name;
# ``check_augment_min_max`` validates user supplied ranges against it.
augment_bound = {
    # Ops that ignore their magnitude.
    **dict.fromkeys(
        ('identity', 'autocontrast', 'equalize', 'invert'),
        (0, 0),
    ),
    # Magnitudes that get rounded to an integer before use.
    'posterize': (0, 8),
    'solarize': (0, 256),
    # Enhancement strengths.
    **dict.fromkeys(
        ('color', 'contrast', 'brightness', 'sharpness'),
        (0.0, 1.0),
    ),
    'rotate': (0.0, 180.0),
    # Translation and shear amounts.
    **dict.fromkeys(
        ('translateX', 'translateY', 'shearX', 'shearY'),
        (0.0, 1.0),
    ),
}
def check_augment_min_max(augment_list):
    """Check that every op's magnitude range fits inside ``augment_bound``.

    Args:
        augment_list: Iterable of ``(op, v_min, v_max)`` triples.

    Raises:
        KeyError: If ``op`` is not a key of ``augment_bound``.
        AssertionError: If ``v_min`` is below the lower bound, or ``v_max``
            is above the upper bound, registered for ``op``.
    """
    for op, v_min, v_max in augment_list:
        sub, sup = augment_bound[op]
        # Raise explicitly rather than via ``assert``: assert statements are
        # removed under ``python -O``, which would silently disable this
        # validation. The exception type and messages are unchanged.
        if not (v_min >= sub):
            raise AssertionError(f"{op} min ({v_min} >= {sub})")
        if not (v_max <= sup):
            raise AssertionError(f"{op} max ({v_max} <= {sup})")
def _random_flip(v):
    """Return ``v`` or ``-v``, picked by one draw of ``random.random()``."""
    keep_sign = random.random() < 0.5
    if keep_sign:
        return v
    return -v
def _affine(img, matrix, fillcolor):
    """Run ``img.transform`` in AFFINE mode with ``matrix``, keeping the size.

    ``fillcolor`` is forwarded to ``img.transform`` unchanged.
    """
    output_size = img.size
    return img.transform(
        output_size,
        PIL.Image.AFFINE,
        matrix,
        fillcolor=fillcolor,
    )
def identity(img, _, **kwargs):
    """Return ``img`` unchanged; the magnitude and keyword options are ignored."""
    unchanged = img
    return unchanged
def autocontrast(img, _, **kwargs):
    """Return ``PIL.ImageOps.autocontrast(img)``; the magnitude is ignored."""
    result = PIL.ImageOps.autocontrast(img)
    return result
def equalize(img, _, **kwargs):
    """Return ``PIL.ImageOps.equalize(img)``; the magnitude is ignored."""
    result = PIL.ImageOps.equalize(img)
    return result
def invert(img, _, **kwargs):
    """Return ``PIL.ImageOps.invert(img)``; the magnitude is ignored."""
    result = PIL.ImageOps.invert(img)
    return result
def posterize(img, v, **kwargs):
    """Posterize ``img``, keeping ``8 - round(v)`` bits.

    A larger ``v`` therefore keeps fewer bits.
    """
    # NOTE(review): v == 8 yields a bit count of 0 -- confirm that
    # PIL.ImageOps.posterize handles this as intended.
    bits = 8 - round(v)
    return PIL.ImageOps.posterize(img, bits)
def solarize(img, v, **kwargs):
    """Solarize ``img`` with a threshold of ``256 - round(v)``.

    A larger ``v`` therefore gives a lower threshold.
    """
    threshold = 256 - round(v)
    return PIL.ImageOps.solarize(img, threshold)
def color(img, v, **kwargs):
    """Enhance ``img`` with ``PIL.ImageEnhance.Color`` by ``1 + v`` or ``1 - v``.

    The sign of ``v`` is chosen at random by ``_random_flip``.
    """
    factor = 1 + _random_flip(v)
    enhancer = PIL.ImageEnhance.Color(img)
    return enhancer.enhance(factor)
def contrast(img, v, **kwargs):
    """Enhance ``img`` with ``PIL.ImageEnhance.Contrast`` by ``1 + v`` or ``1 - v``.

    The sign of ``v`` is chosen at random by ``_random_flip``.
    """
    factor = 1 + _random_flip(v)
    enhancer = PIL.ImageEnhance.Contrast(img)
    return enhancer.enhance(factor)
def brightness(img, v, **kwargs):
    """Enhance ``img`` with ``PIL.ImageEnhance.Brightness`` by ``1 + v`` or ``1 - v``.

    The sign of ``v`` is chosen at random by ``_random_flip``.
    """
    factor = 1 + _random_flip(v)
    enhancer = PIL.ImageEnhance.Brightness(img)
    return enhancer.enhance(factor)
def sharpness(img, v, **kwargs):
    """Enhance ``img`` with ``PIL.ImageEnhance.Sharpness`` by ``1 + v`` or ``1 - v``.

    The sign of ``v`` is chosen at random by ``_random_flip``.
    """
    factor = 1 + _random_flip(v)
    enhancer = PIL.ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
def rotate(img, v, fillcolor='black'):
    """Call ``img.rotate`` with ``v`` or ``-v``, the sign chosen at random.

    ``fillcolor`` is forwarded to ``img.rotate``.
    """
    angle = _random_flip(v)
    return img.rotate(angle, fillcolor=fillcolor)
def translateX(img, v, fillcolor='black'):
    """Translate ``img`` along x by ``v * img.size[0]``, in a random direction.

    ``fillcolor`` is passed on to ``_affine``.
    """
    extent = img.size[0]
    offset = _random_flip(v * extent)
    coefficients = (1, 0, offset, 0, 1, 0)
    return _affine(img, coefficients, fillcolor)
def translateY(img, v, fillcolor='black'):
    """Translate ``img`` along y by ``v * img.size[1]``, in a random direction.

    ``fillcolor`` is passed on to ``_affine``.
    """
    extent = img.size[1]
    offset = _random_flip(v * extent)
    coefficients = (1, 0, 0, 0, 1, offset)
    return _affine(img, coefficients, fillcolor)
def shearX(img, v, fillcolor='black'):
    """Shear ``img`` along x with coefficient ``v`` or ``-v``, chosen at random.

    ``fillcolor`` is passed on to ``_affine``.
    """
    shear = _random_flip(v)
    coefficients = (1, shear, 0, 0, 1, 0)
    return _affine(img, coefficients, fillcolor)
def shearY(img, v, fillcolor='black'):
    """Shear ``img`` along y with coefficient ``v`` or ``-v``, chosen at random.

    ``fillcolor`` is passed on to ``_affine``.
    """
    shear = _random_flip(v)
    coefficients = (1, 0, 0, shear, 1, 0)
    return _affine(img, coefficients, fillcolor)
class RandAug:
    """Callable that applies ``n`` randomly chosen augmentation ops to an image.

    Each call draws ``n`` entries (with replacement) from ``augment_list`` and
    applies them in order. Every entry is an ``(op, v_min, v_max)`` triple in
    which ``op`` is the name of a module-level op function.
    """

    # Divisor that turns ``m`` into a fraction of each op's range.
    QUANTIZE_LEVEL = 10

    # Ops and magnitude ranges used when no ``augment_list`` is supplied.
    DEFAULT_AUGMENT_LIST = [
        ('identity', 0, 0),
        ('autocontrast', 0, 0),
        ('equalize', 0, 0),
        ('posterize', 4, 8),
        ('solarize', 0, 128),
        ('color', 0.05, 0.95),
        ('contrast', 0.05, 0.95),
        ('brightness', 0.05, 0.95),
        ('sharpness', 0.05, 0.95),
        ('rotate', 0, 30),
        ('translateX', 0, 0.3),
        ('translateY', 0, 0.3),
        ('shearX', 0, 0.3),
        ('shearY', 0, 0.3),
    ]

    def __init__(self, n=2, m=10, augment_list=None, fillcolor='white'):
        """Store the policy settings and validate the op ranges.

        Args:
            n: Number of ops applied per call (converted with ``int``).
            m: Magnitude scale, divided by ``QUANTIZE_LEVEL`` when sampling
                (converted with ``int``).
            augment_list: ``(op, v_min, v_max)`` triples to sample from;
                ``None`` selects a copy of ``DEFAULT_AUGMENT_LIST``.
            fillcolor: Passed to every op as the ``fillcolor`` keyword.

        Raises:
            KeyError, AssertionError: Propagated from
                ``check_augment_min_max`` for unknown ops or out-of-bound
                ranges.
        """
        self.n = int(n)
        self.m = int(m)
        self.fillcolor = fillcolor
        if augment_list is None:
            # Copy the default so that editing one instance's list in place
            # cannot change the class-level default shared by all instances.
            augment_list = list(self.DEFAULT_AUGMENT_LIST)
        self.augment_list = augment_list
        check_augment_min_max(self.augment_list)

    def transform(self, img, op, v):
        """Apply the module-level op function named ``op`` to ``img``.

        The function is looked up by name in this module's globals and is
        called with magnitude ``v`` and this instance's ``fillcolor``.
        """
        return globals()[op](img, v, fillcolor=self.fillcolor)

    def __call__(self, img):
        """Return ``img`` after ``n`` randomly chosen ops have been applied.

        For each chosen op a uniform random fraction, scaled by
        ``m / QUANTIZE_LEVEL``, is mapped linearly onto the op's
        ``[v_min, v_max]`` range to obtain its magnitude.
        """
        ops = random.choices(self.augment_list, k=self.n)
        for op, v_min, v_max in ops:
            v = random.random() * self.m / self.QUANTIZE_LEVEL
            v = v * (v_max - v_min) + v_min
            img = self.transform(img, op, v)
        return img
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment