Skip to content

Instantly share code, notes, and snippets.

@pangyuteng
Last active September 20, 2024 00:24
Show Gist options
  • Save pangyuteng/fdbf0e13cd9173dc11aabccb30f8a2ad to your computer and use it in GitHub Desktop.
Save pangyuteng/fdbf0e13cd9173dc11aabccb30f8a2ad to your computer and use it in GitHub Desktop.
keras sample data generator, augmentation of keypoints and mask with albumentations
# sample code to augment image,mask and keypoints with albumentations
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import albumentations as A
# Build a synthetic 256x256 test case: a noisy image with a brighter square,
# the matching binary mask, and a single keypoint given in (row, col) order.
image = np.random.rand(256, 256)
image[64:128, 64:128] += 0.5
mask = np.zeros((256, 256))
mask[64:128, 64:128] = 1
point = [64, 128]

# Figure 1: the un-augmented inputs. A dedicated figure keeps these panels
# from being overwritten by the augmented plots when run as a plain script.
plt.figure()
plt.subplot(121)
plt.imshow(image)
plt.subplot(122)
plt.imshow(mask)
# scatter() expects (x, y), i.e. (col, row), hence the swapped indices.
plt.scatter(point[1], point[0])

# 'yx' tells albumentations the keypoints are (row, col), matching `point`.
aug_pipeline = A.Compose(
    [
        A.ShiftScaleRotate(),
    ],
    p=1,
    keypoint_params=A.KeypointParams('yx'),
)
augmented = aug_pipeline(
    image=image,
    mask=mask,
    keypoints=[point],
)

# Figure 2: the augmented image and mask, with the transformed keypoint.
plt.figure()
plt.subplot(121)
plt.imshow(augmented['image'])
plt.subplot(122)
plt.imshow(augmented['mask'])
# The keypoint list can come back empty when the transform moves the point
# outside the frame, so only plot it when it survived.
if len(augmented['keypoints']) > 0:
    aug_point = augmented['keypoints'][0]
    plt.scatter(aug_point[1], aug_point[0])
import pandas as pd
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
import imageio
# Random augmentation settings handed to Keras' ImageDataGenerator:
# small rotations and shifts, flips on both axes, and zero-filled borders.
kwargs = {
    'rotation_range': 10,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
    'horizontal_flip': True,
    'vertical_flip': True,
    'cval': 0,
    'fill_mode': 'constant',
}
datagen = ImageDataGenerator(**kwargs)
def readimage(image_file, augment):
    """Load one image, window it to [0, 1], halve its size and optionally augment.

    Args:
        image_file: path of the image to load with imageio.
        augment: when truthy, apply one random transform from the
            module-level ``datagen``.

    Returns:
        A float64 array with intensities rescaled from [-1000, 1000] into
        [0, 1] and the first two axes down-sampled by a factor of 2.
    """
    # `zoom` was called without ever being imported in this file; import it
    # locally so the function is self-contained.
    from scipy.ndimage import zoom

    # imageio.read() returns a Reader object, not pixel data; imread() returns
    # the image array itself.
    x_sample = imageio.imread(image_file)
    # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
    x_sample = x_sample.astype(np.float64)
    minval, maxval = -1000, 1000
    x_sample = (x_sample - minval) / (maxval - minval)
    x_sample = x_sample.clip(0, 1)
    # Zoom factors are per axis, so the array must be 3-D here
    # (presumably height, width, channels - TODO confirm against the data).
    x_sample = zoom(x_sample, [0.5, 0.5, 1])
    if augment:
        x_sample = datagen.random_transform(x_sample)
    return x_sample
from keras.utils import Sequence
# https://github.com/keras-team/keras/issues/9707
class MySeriGenerator(Sequence):
    """Keras ``Sequence`` that lazily loads image batches listed in a DataFrame.

    The DataFrame must expose a ``contrast`` column (values convertible to
    ``int``, used as labels) and an ``image_file`` column (paths passed to
    ``readimage``).

    Args:
        mydf: DataFrame-like object with ``contrast`` and ``image_file`` columns.
        batch_size: number of samples per batch.
        shuffle: when true, sample order is reshuffled at the end of each epoch.
        augment: forwarded to ``readimage`` to toggle random augmentation.
    """

    def __init__(self, mydf, batch_size=8, shuffle=True, augment=False):
        # Newer Keras versions expect Sequence/PyDataset subclasses to call the
        # base constructor; on older versions this is a harmless no-op.
        super().__init__()
        self.y = np.array([int(x) for x in mydf.contrast.values])
        self.x = np.array(list(mydf.image_file.values))
        self.indices = np.arange(len(self.y))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.indices) // self.batch_size

    def __getitem__(self, idx):
        """Load and return batch ``idx`` as ``(images, labels)`` arrays."""
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = self.x[inds]
        batch_y = self.y[inds]
        # Images are read from disk here, one file per sample in the batch.
        x = [readimage(filename, self.augment) for filename in batch_x]
        return np.array(x), np.array(batch_y)

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indices)
#
# ref
# https://stackoverflow.com/questions/37340129/tensorflow-training-on-my-own-image
# https://medium.com/the-owl/creating-a-tf-dataset-using-a-data-generator-5e5564609e64
# https://github.com/keras-team/keras-io/blob/master/examples/generative/ddim.py
#
import os
import sys
from pathlib import Path
import numpy as np
import tensorflow as tf
import cv2
import imageio
from skimage.transform import resize
import matplotlib.pyplot as plt
# Pipeline settings shared by the dataset helpers in this script.
dataset_repetitions= 5  # how many times each file list is repeated via Dataset.repeat()
image_size = 64  # images are resized to image_size x image_size
batch_size = 64  # samples per batch; the shuffle buffer is 10 * batch_size
# for deep-lesion dataset: intensity window that png_read rescales into [0, 1]
min_val,max_val = -1000,1000
def png_read(file_path):
    """Read a 16-bit PNG and return it windowed to [0, 1] as float32.

    Intended to be wrapped by ``tf.numpy_function``, which is why the path
    arrives as ``bytes`` and a second dummy output is returned.

    Args:
        file_path: UTF-8 encoded path of the PNG file, as ``bytes``.

    Returns:
        A tuple ``(image, dummy)``: ``image`` is a float32 array with a
        trailing singleton axis appended and values clipped to [0, 1];
        ``dummy`` is a float32 array holding a single 0.0.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    file_path = file_path.decode('utf-8')
    image = cv2.imread(file_path, -1)  # -1 is needed for 16-bit image
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque AttributeError on the next line.
    if image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {file_path}")
    # Remove the 32768 storage offset to recover signed HU values.
    image = (image.astype(np.int32) - 32768).astype(np.int16)  # HU
    image = image.astype(np.float32)
    image = ((image - min_val) / (max_val - min_val)).clip(0, 1)
    image = np.expand_dims(image, axis=-1)
    dummy = np.array([0.0]).astype(np.float32)
    return image, dummy
def parse_fn_py(file_path):
    """Load a PNG through ``png_read`` and return a 3-channel resized image tensor.

    Args:
        file_path: scalar string tensor holding the PNG path.

    Returns:
        A float32 tensor of shape ``(image_size, image_size, 3)``.
    """
    image, _ = tf.numpy_function(
        func=png_read,
        inp=[file_path],
        Tout=[tf.float32, tf.float32],
    )
    # tf.numpy_function outputs carry no static shape. png_read appends one
    # trailing channel axis, so declare (H, W, 1) here; without it the channel
    # dimension stays unknown after tiling and resizing.
    image.set_shape([None, None, 1])
    # Replicate the single channel three times to get an RGB-like layout.
    image = tf.tile(image, [1, 1, 3])
    image = tf.image.resize(image, [image_size, image_size], antialias=True)
    return image
def parse_fn(filename):
    """Decode a JPEG file into a float32 RGB tensor scaled to [0, 1] and resized.

    Args:
        filename: scalar string tensor holding the JPEG path.

    Returns:
        A float32 tensor of shape ``(image_size, image_size, 3)``.
    """
    raw_bytes = tf.io.read_file(filename)
    rgb = tf.image.decode_jpeg(raw_bytes, channels=3)
    scaled = tf.cast(rgb, tf.float32) / 255.0
    return tf.image.resize(scaled, [image_size, image_size], antialias=True)
def _build_pipeline(path_list):
    """Turn a list of JPEG paths into a repeated, shuffled, batched, prefetched dataset."""
    # An explicit string dtype keeps an empty split from defaulting to float32.
    filenames = tf.constant(path_list, dtype=tf.string)
    dataset = (
        tf.data.Dataset.from_tensor_slices(filenames)
        .repeat(dataset_repetitions)
        .shuffle(10 * batch_size)
        .map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    )
    return dataset.batch(batch_size, drop_remainder=True).prefetch(
        buffer_size=tf.data.AUTOTUNE
    )


def prepare_dataset(directory='./celeba_gan/img_align_celeba', val_count=1000):
    """Build the training and validation ``tf.data`` pipelines from a JPEG folder.

    Args:
        directory: folder searched recursively for ``*.jpg`` files.
        val_count: number of files (taken from the end of the sorted list)
            held out for validation.

    Returns:
        A tuple ``(train_ds, val_ds)`` of batched, prefetched datasets.

    Raises:
        ValueError: if no ``*.jpg`` file is found under ``directory``.
    """
    # rglob order is filesystem dependent; sorting makes the train/val split
    # reproducible across runs and machines.
    path_list = sorted(str(x) for x in Path(directory).rglob('*.jpg'))
    print(len(path_list))
    if not path_list:
        raise ValueError(f"no .jpg files found under {directory!r}")

    # Clamp at zero so fewer than val_count files yields an empty training
    # split, exactly as slicing with [:-val_count] would.
    split = max(len(path_list) - val_count, 0)
    train_ds = _build_pipeline(path_list[:split])
    val_ds = _build_pipeline(path_list[split:])
    return train_ds, val_ds
train_dataset, val_dataset = prepare_dataset()

# Save a 3x3 preview grid of the first training batch to tmp/test.png.
plt.figure(figsize=(10, 10))
for images in train_dataset.take(1):
    # At most nine panels fit in the grid, fewer if the batch is smaller.
    for idx in range(min(batch_size, 9)):
        plt.subplot(3, 3, idx + 1)
        pixels = (images[idx, :].numpy() * 255).astype("uint8")
        plt.imshow(pixels)
        plt.axis("off")
os.makedirs('tmp', exist_ok=True)
plt.savefig("tmp/test.png")
plt.close()
@pangyuteng
Copy link
Author

@pangyuteng
Copy link
Author

pangyuteng commented Jul 9, 2022

To implement this with tf.data.Dataset, see the issue below:
tensorflow/tensorflow#39523 (comment)

Performance may be similar or comparable to keras.utils.Sequence, but may need more tweaking.
https://stackoverflow.com/questions/55852831/tf-data-vs-keras-utils-sequence-performance/59492947#59492947
https://www.tensorflow.org/guide/data_performance

Alternatively, pass data to readimage via Dataset.map and tf.numpy_function.
https://albumentations.ai/docs/examples/tensorflow-example

Then follow the comment below for multi-GPU training.
tensorflow/tensorflow#42146 (comment)

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
train_dataset = tf.data.Dataset.BLAH
train_dataset = train_dataset.with_options(options).MOREBLAH

@pangyuteng
Copy link
Author

pangyuteng commented Nov 15, 2023

#
# REF
# https://gist.github.com/Lexie88rus/b6e66497a1b4b14aa01cc41e126a7c20
# https://www.kaggle.com/code/tachyon777/hubmap-tachyon-dataaugmentation-examples
#

# Image augmentation pipeline: the Compose itself always runs (p = 1), while
# each transform inside carries its own probability.
# NOTE(review): RandomBrightness, RandomContrast and Cutout appear to be
# deprecated or removed in newer albumentations releases - confirm against the
# installed version before reusing this pipeline.
augmentation_pipeline = A.Compose(
    [
        # Photometric jitter.
        A.RandomBrightness(limit=0.2, p=0.75),
        A.RandomContrast(limit=0.2, p=0.75),
        # Exactly one blur/noise transform is picked when this group fires.
        A.OneOf([
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ], p=0.7),
        # Exactly one geometric distortion is picked when this group fires.
        A.OneOf([
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.),
            A.ElasticTransform(alpha=3),
        ], p=0.7),
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        # border_mode=0 is presumably cv2.BORDER_CONSTANT - TODO confirm.
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        A.Resize(256, 256), # NOTE(review): the author marked this "?" - unclear whether the resize is required
        A.Cutout(max_h_size=int(256 * 0.375), max_w_size=int(256 * 0.375), num_holes=1, p=0.7), # one hole of at most 96x96 px
    ],
    p = 1
)



@pangyuteng
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment