Last active
February 22, 2022 11:40
-
-
Save fclesio/c6f1861f9ed7713ddff987d918ebcc64 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import logging
import os
import random
import sys
import time

sys.path.append(os.getcwd() + os.sep + os.pardir)

import imageio
import imgaug as ia
import numpy as np
from imgaug import augmenters as iaa
from imgaug.augmentables.batches import UnnormalizedBatch
# Seed both random sources with the same fixed value (42): the stdlib
# `random` module (used by random.shuffle when splitting the file list)
# and imgaug's global random state.
random.seed(42)
ia.seed(42)
# Directory used by each pipeline stage, keyed by stage name.
file_paths = {
    "source": "src/main/data/source",
    "reshaped": "src/main/data/reshaped",
    "test_augmented": "src/main/data/test_augmented",
    "train_augmented": "src/main/data/train_augmented",
    "validation_augmented": "src/main/data/validation_augmented",
}

# Fractions of the reshaped files assigned to each split.
TRAIN_SPLIT, TEST_SPLIT, VALIDATION_SPLIT = 0.80, 0.10, 0.10

# Augmentation volume per image: NB_BATCHES batches of BATCH_SIZE copies.
NB_BATCHES = 10
BATCH_SIZE = 200

# Not referenced by any function visible in this file.
NB_TRAIN_SAMPLES = 20000
def get_train_test_validation_file_lists(files_list):
    """Partition ``files_list`` into train, test and validation lists.

    The three results are consecutive, non-overlapping slices taken in order
    from the start of ``files_list``. Their sizes are
    ``int(len(files_list) * fraction)`` for ``TRAIN_SPLIT``, ``TEST_SPLIT``
    and ``VALIDATION_SPLIT`` respectively. Because each size is truncated
    to an integer, some trailing items may end up in none of the splits.

    The input list is left untouched; the slices are new lists.

    Args:
        files_list: Sequence of file names, already in the desired order
            (the caller shuffles it beforehand).

    Returns:
        A ``(train, test, validation)`` tuple of lists.
    """
    total = len(files_list)
    train_end = int(total * TRAIN_SPLIT)
    test_end = train_end + int(total * TEST_SPLIT)
    validation_end = test_end + int(total * VALIDATION_SPLIT)

    return (
        files_list[:train_end],
        files_list[train_end:test_end],
        files_list[test_end:validation_end],
    )
def get_images_batch_stack(image, batch_size=BATCH_SIZE):
    """Return a list holding ``batch_size`` independent copies of ``image``.

    Each element is produced by ``np.copy``, so the copies do not share
    memory with ``image`` or with each other.
    """
    stack = []
    for _ in range(batch_size):
        stack.append(np.copy(image))
    return stack
def get_image_batches(images, nb_batches=NB_BATCHES):
    """Wrap ``images`` in ``nb_batches`` separate ``UnnormalizedBatch`` objects.

    Every batch is constructed from the same ``images`` object; it is not
    copied per batch.
    """
    batches = []
    for _ in range(nb_batches):
        batches.append(UnnormalizedBatch(images=images))
    return batches
def reshape_images(image_name,
                   source_path=file_paths["source"],
                   reshaped_path=file_paths["reshaped"],
                   ):
    """Resize one image and write it to the reshaped folder.

    Reads ``<source_path>/<image_name>``, runs it through an imgaug
    ``Resize(256)`` step with linear interpolation, and writes the result to
    ``<reshaped_path>/<image_name>``.

    Args:
        image_name: File name (with extension) of the image to process.
        source_path: Directory the original image is read from.
        reshaped_path: Directory the resized image is written to.
    """
    origin = "/".join((source_path, image_name))
    target = "/".join((reshaped_path, image_name))

    resizer = iaa.Sequential([iaa.Resize(256, interpolation=["linear"])])
    resized = resizer(image=imageio.imread(origin))

    imageio.imwrite(target, resized)
def reshape_save_source_images(file_path=file_paths["source"]):
    """Resize every JPG found in ``file_path`` via ``reshape_images``.

    Only entries whose name ends with ``.jpg`` (case-insensitive) are
    processed, so a name that merely contains "jpg" somewhere else is no
    longer picked up by accident.

    Args:
        file_path: Directory to scan for source images. The same directory
            is forwarded to ``reshape_images`` so the files are read from
            where they were listed.
    """
    for image_name in os.listdir(file_path):
        if not image_name.lower().endswith(".jpg"):
            continue
        # Forward the directory explicitly; otherwise a non-default
        # ``file_path`` would be listed here but read from the default folder.
        reshape_images(image_name=image_name, source_path=file_path)
def train_test_validation_split_sets(file_path=file_paths["reshaped"]):
    """Shuffle the files in ``file_path`` and split them into three sets.

    Args:
        file_path: Directory whose entries are split.

    Returns:
        A ``(train, test, validation)`` tuple of file-name lists, as produced
        by ``get_train_test_validation_file_lists``.

    Raises:
        Exception: If the directory holds fewer than 10 entries.
    """
    # os.listdir returns names in arbitrary order, so sort first: with the
    # seeded `random` module the shuffle then gives the same split every run.
    list_reshaped_files = sorted(os.listdir(file_path))
    random.shuffle(list_reshaped_files)

    qty_reshaped_files = len(list_reshaped_files)
    if qty_reshaped_files < 10:
        raise Exception(f'Less than 10 files in the reshaped folder. Number of files: {qty_reshaped_files}')

    return get_train_test_validation_file_lists(list_reshaped_files)
# Augmentation pipeline shared by all three splits: rotation, additive
# Gaussian noise, then cropping. Each tuple is a (min, max) range handed to
# imgaug -- presumably sampled per image; confirm against the imgaug docs.
random_mixed_effects_aug = iaa.Sequential([
    iaa.Affine(rotate=(-60, 60)),
    iaa.AdditiveGaussianNoise(scale=(10, 60)),
    iaa.Crop(percent=(0, 0.2)),
])
def generate_batch_augmentation(image_name,
                                augmentation_effect,
                                reshaped_path=file_paths["reshaped"]):
    """Augment many copies of one reshaped image and return the batches.

    The image at ``<reshaped_path>/<image_name>`` is read, replicated with
    ``get_images_batch_stack``, wrapped into batches with
    ``get_image_batches`` and passed to
    ``augmentation_effect.augment_batches(..., background=True)``.

    Args:
        image_name: File name (with extension) inside ``reshaped_path``.
        augmentation_effect: imgaug augmenter exposing ``augment_batches``.
        reshaped_path: Directory holding the resized images.

    Returns:
        A list with the augmented batches yielded by ``augment_batches``.
    """
    source_image = imageio.imread("/".join((reshaped_path, image_name)))
    pending_batches = get_image_batches(get_images_batch_stack(source_image))
    augmented = augmentation_effect.augment_batches(pending_batches,
                                                    background=True)
    return list(augmented)
def generate_mix_effects_aug(
        image_name_with_extention,
        effect_name,
        effect_object,
        split_set,
        destination_path,
        reshaped_path=file_paths["reshaped"],
        nb_batchs=NB_BATCHES,
        batch_sizes=BATCH_SIZE):
    """Augment one image and write every augmented copy as a JPG.

    Output files are named
    ``<stem>_<effect_name>_nb_batch<i>_batch_size_<j>_<split_set>.jpg`` and
    written into ``destination_path``, where ``i`` ranges over ``nb_batchs``
    and ``j`` over ``batch_sizes``.

    Args:
        image_name_with_extention: File name (with extension) of the image
            inside ``reshaped_path``.
        effect_name: Label embedded in each output file name.
        effect_object: imgaug augmenter applied to the image copies.
        split_set: Split label (e.g. ``'train'``) embedded in the file names.
        destination_path: Directory receiving the augmented images.
        reshaped_path: Directory the input image is read from.
        nb_batchs: Number of augmented batches to write. Must not exceed the
            number of batches ``generate_batch_augmentation`` produces.
        batch_sizes: Number of images written from each batch. Must not
            exceed the number of images per batch.
    """
    # splitext drops only the final extension; splitting on the first '.'
    # would map 'a.b.jpg' and 'a.c.jpg' to the same stem and overwrite output.
    image_name = os.path.splitext(image_name_with_extention)[0]

    # Forward reshaped_path so a non-default folder is actually honoured.
    image_aug_batch = generate_batch_augmentation(
        image_name_with_extention,
        effect_object,
        reshaped_path=reshaped_path,
    )

    for nb_batch in range(nb_batchs):
        batch_images = image_aug_batch[nb_batch].images_aug
        for batch_size in range(batch_sizes):
            output_path = (
                f'{destination_path}/{image_name}_{effect_name}'
                f'_nb_batch{nb_batch}_batch_size_{batch_size}_{split_set}.jpg'
            )
            imageio.imwrite(output_path, batch_images[batch_size])
def generate_files_mixed_effects(list_set_files, effect_aug,
                                 effect_name, effect_object,
                                 split_set, destination_path):
    """Run ``effect_aug`` on every file of one split and log the time taken.

    Args:
        list_set_files: File names (with extension) belonging to the split.
        effect_aug: Callable invoked once per file with the keyword arguments
            ``image_name_with_extention``, ``effect_name``, ``effect_object``,
            ``split_set`` and ``destination_path``.
        effect_name: Label forwarded to ``effect_aug``.
        effect_object: Augmenter forwarded to ``effect_aug``.
        split_set: Split label, forwarded and used in the log message.
        destination_path: Output directory, forwarded to ``effect_aug``.
            It is created first if it does not exist yet.
    """
    # Writing into a missing directory would fail on the first image.
    if destination_path:
        os.makedirs(destination_path, exist_ok=True)

    start_time = time.time()
    for image in list_set_files:
        effect_aug(
            image_name_with_extention=image,
            effect_name=effect_name,
            effect_object=effect_object,
            split_set=split_set,
            destination_path=destination_path,
        )
    time_elapsed = time.time() - start_time

    logging.info(
        'Time elapsed - %s set: %s',
        split_set,
        time.strftime("%H:%M:%S", time.gmtime(time_elapsed)),
    )
def main():
    """Run the whole pipeline: resize, split, then augment each split.

    Source images are resized into the reshaped folder, the reshaped files
    are split into train/test/validation lists, and each list is augmented
    with ``random_mixed_effects_aug`` into its own ``*_augmented`` folder.
    """
    # The root logger defaults to WARNING, which would drop every info
    # message below. basicConfig does nothing if logging is already set up.
    logging.basicConfig(level=logging.INFO)

    logging.info("Start Jaguaritica Augmentaton wrapper...")

    logging.info("Reshape source images for 256x256 definition")
    reshape_save_source_images()

    logging.info("Split reshaped images in Train, Test and Validation sets")
    list_train_files, list_test_files, list_validation_files \
        = train_test_validation_split_sets()

    # (label used in the log message, split name, files of that split)
    split_jobs = (
        ("Training", "train", list_train_files),
        ("Test", "test", list_test_files),
        ("Validation", "validation", list_validation_files),
    )
    for set_label, split_set, list_set_files in split_jobs:
        logging.info("Generate mixed effects for the %s Set", set_label)
        generate_files_mixed_effects(
            list_set_files=list_set_files,
            effect_aug=generate_mix_effects_aug,
            effect_name='random_mixed_effects_aug',
            effect_object=random_mixed_effects_aug,
            split_set=split_set,
            destination_path=file_paths[f"{split_set}_augmented"],
        )

    logging.info("Jaguaritica Augmentaton wrapper finished")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment