Last active
September 20, 2024 00:24
-
-
Save pangyuteng/fdbf0e13cd9173dc11aabccb30f8a2ad to your computer and use it in GitHub Desktop.
Keras sample data generator, plus augmentation of keypoints and masks with Albumentations
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample code: augment an image, its mask and a keypoint together with albumentations.
import numpy as np
import matplotlib.pyplot as plt
# %matplotlib inline  # IPython/Jupyter magic; uncomment when running in a notebook
import albumentations as A

# Synthetic 256x256 image with a brighter 64x64 square, plus the matching mask.
# albumentations documents uint8/float32 as supported image dtypes, so cast the
# float64 output of np.random.rand to float32.
image = np.random.rand(256, 256).astype(np.float32)
image[64:128, 64:128] += 0.5
mask = np.zeros((256, 256))
mask[64:128, 64:128] = 1
# Keypoint in (row, col) order, matching the 'yx' keypoint format used below.
point = [64, 128]

aug_pipeline = A.Compose(
    [A.ShiftScaleRotate()],
    p=1,
    keypoint_params=A.KeypointParams(format='yx'),
)
augmented = aug_pipeline(image=image, mask=mask, keypoints=[point])

# One 2x2 figure: top row = input, bottom row = augmented output.
plt.subplot(221)
plt.imshow(image)
plt.title('image')
plt.subplot(222)
plt.imshow(mask)
plt.scatter(point[1], point[0])  # scatter takes (x, y) = (col, row)
plt.title('mask + keypoint')
plt.subplot(223)
plt.imshow(augmented['image'])
plt.title('augmented image')
plt.subplot(224)
plt.imshow(augmented['mask'])
# Keypoints that leave the image can be removed by KeypointParams, so the
# result may be empty; len() also works if keypoints come back as an array.
if len(augmented['keypoints']) > 0:
    aug_y, aug_x = augmented['keypoints'][0][:2]
    plt.scatter(aug_x, aug_y)
plt.title('augmented mask + keypoint')
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
import numpy as np | |
from keras.preprocessing.image import ImageDataGenerator | |
import imageio | |
# Random-transform settings for the Keras ImageDataGenerator that readimage uses
# when augment=True: small rotation (rotation_range, in degrees), up to 10% shift
# along each axis, random horizontal and vertical flips, and pixels exposed by
# the transform filled with the constant 0.
kwargs = {
    'rotation_range': 10,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
    'horizontal_flip': True,
    'vertical_flip': True,
    'cval': 0,
    'fill_mode': 'constant',
}
datagen = ImageDataGenerator(**kwargs)
def readimage(image_file, augment, minval=-1000, maxval=1000):
    """Load one image, window-normalize it to [0, 1], halve it spatially, optionally augment.

    Parameters
    ----------
    image_file : str or path-like
        Image path, read with ``imageio.imread``.
    augment : bool
        When true, apply ``datagen.random_transform`` (the module-level
        ImageDataGenerator) to the downsampled image.
    minval, maxval : float
        Intensity window mapped linearly onto [0, 1]; values outside are
        clipped. Defaults reproduce the original hard-coded (-1000, 1000).

    Returns
    -------
    numpy.ndarray
        float64 array of shape (rows', cols', channels) with rows'/cols'
        roughly half the input size. A 2-D (grayscale) input gets a trailing
        channel axis of size 1.
    """
    from scipy.ndimage import zoom  # used below but never imported in the original

    # imageio.read() returns a Reader object, not pixel data; imread returns the array.
    x_sample = imageio.imread(image_file)
    # np.float (removed in NumPy 1.24) was an alias for the builtin float, i.e. float64.
    x_sample = np.asarray(x_sample, dtype=np.float64)
    if x_sample.ndim == 2:
        # Grayscale: add a channel axis so the 3-element zoom factors and
        # random_transform (which expects rows, cols, channels) both apply.
        x_sample = x_sample[..., np.newaxis]
    x_sample = (x_sample - minval) / (maxval - minval)
    x_sample = x_sample.clip(0, 1)
    # Halve both spatial axes; leave the channel axis unchanged.
    x_sample = zoom(x_sample, [0.5, 0.5, 1])
    if augment:
        x_sample = datagen.random_transform(x_sample)
    return x_sample
from keras.utils import Sequence | |
# https://github.com/keras-team/keras/issues/9707 | |
class MySeriGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, labels)`` batches from a DataFrame.

    ``mydf`` must have an ``image_file`` column (paths passed to ``readimage``)
    and a ``contrast`` column (converted to int labels).
    """

    def __init__(self, mydf, batch_size=8, shuffle=True, augment=False):
        # Newer Keras versions (Sequence == PyDataset) expect the base
        # initializer to run; older Sequence has no __init__, so this is harmless.
        super().__init__()
        self.y = np.array([int(v) for v in mydf['contrast'].values])
        self.x = np.array(list(mydf['image_file'].values))
        self.indices = np.arange(len(self.y))
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augment = augment
        # on_epoch_end only runs after an epoch, so shuffle up front as well;
        # otherwise the first epoch is always seen in DataFrame order.
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        """Number of batches per epoch, including a final partial batch."""
        # Ceiling division: floor silently dropped trailing samples and
        # returned 0 batches whenever len(data) < batch_size.
        return (len(self.indices) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        """Return batch ``idx`` as a pair of numpy arrays ``(images, labels)``."""
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_files = self.x[inds]
        batch_labels = self.y[inds]
        images = [readimage(filename, self.augment) for filename in batch_files]
        return np.array(images), np.array(batch_labels)

    def on_epoch_end(self):
        """Reshuffle the sample order between epochs when ``shuffle`` is set."""
        if self.shuffle:
            np.random.shuffle(self.indices)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# | |
# ref | |
# https://stackoverflow.com/questions/37340129/tensorflow-training-on-my-own-image | |
# https://medium.com/the-owl/creating-a-tf-dataset-using-a-data-generator-5e5564609e64 | |
# https://github.com/keras-team/keras-io/blob/master/examples/generative/ddim.py | |
# | |
import os | |
import sys | |
from pathlib import Path | |
import numpy as np | |
import tensorflow as tf | |
import cv2 | |
import imageio | |
from skimage.transform import resize | |
import matplotlib.pyplot as plt | |
# tf.data pipeline settings.
dataset_repetitions = 5  # passed to Dataset.repeat() for both splits
image_size = 64  # images are resized to image_size x image_size
batch_size = 64

# Intensity window used by png_read for the DeepLesion PNGs (values after the
# -32768 offset), linearly mapped onto [0, 1].
min_val = -1000
max_val = 1000
def png_read(file_path):
    """Read a 16-bit DeepLesion PNG and return it window-normalized to [0, 1].

    Meant to be called through ``tf.numpy_function``, which passes the path as
    ``bytes``; a plain ``str`` path is accepted as well.

    Returns
    -------
    image : numpy.ndarray
        float32 array with a trailing channel axis added (rows, cols, 1),
        assuming the PNG is single-channel.
    dummy : numpy.ndarray
        float32 array ``[0.0]``; a placeholder second output.

    Raises
    ------
    IOError
        If OpenCV cannot read the file.
    """
    if isinstance(file_path, bytes):
        file_path = file_path.decode('utf-8')
    image = cv2.imread(file_path, -1)  # -1 (IMREAD_UNCHANGED) keeps the 16-bit depth
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise IOError(f"could not read image: {file_path}")
    image = (image.astype(np.int32) - 32768).astype(np.int16)  # HU
    image = image.astype(np.float32)
    image = ((image - min_val) / (max_val - min_val)).clip(0, 1)
    image = np.expand_dims(image, axis=-1)
    dummy = np.array([0.0], dtype=np.float32)
    return image, dummy
def parse_fn_py(file_path):
    """tf.data map function for DeepLesion PNGs.

    Reads the file through ``png_read`` (wrapped in ``tf.numpy_function``),
    replicates the single channel to 3 channels and resizes to
    ``image_size x image_size`` with antialiasing.

    Returns a float32 tensor of static shape (image_size, image_size, 3).
    """
    image, _ = tf.numpy_function(
        func=png_read,
        inp=[file_path],
        Tout=[tf.float32, tf.float32],
    )
    # numpy_function outputs carry no static shape. Without this the channel
    # dimension stays unknown (None) through tile/resize, which layers such as
    # Conv2D reject. png_read returns (rows, cols, 1).
    image.set_shape([None, None, 1])
    image = tf.tile(image, [1, 1, 3])  # grayscale -> 3 identical channels
    image = tf.image.resize(image, [image_size, image_size], antialias=True)
    return image
def parse_fn(filename):
    """Decode the JPEG at ``filename`` as RGB, scale to [0, 1] and resize.

    Returns a float32 tensor resized to ``image_size x image_size`` (antialiased).
    """
    raw_bytes = tf.io.read_file(filename)
    pixels = tf.cast(tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32)
    pixels = pixels / 255.0
    return tf.image.resize(pixels, [image_size, image_size], antialias=True)
def _make_image_dataset(paths):
    """Build a repeated, shuffled, decoded (parse_fn), batched and prefetched dataset from paths."""
    ds = tf.data.Dataset.from_tensor_slices(tf.constant(paths))
    ds = ds.repeat(dataset_repetitions).shuffle(10 * batch_size)
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size, drop_remainder=True).prefetch(buffer_size=tf.data.AUTOTUNE)


def prepare_dataset(directory='./celeba_gan/img_align_celeba', val_size=1000):
    """Return ``(train_ds, val_ds)`` tf.data pipelines over the ``*.jpg`` files in ``directory``.

    Files are found recursively and sorted, so the split is reproducible
    (``Path.rglob`` yields an arbitrary, filesystem-dependent order). The last
    ``val_size`` files form the validation split; the rest form the training split.

    Raises
    ------
    ValueError
        If ``val_size`` is not positive or there are not more than
        ``val_size`` images (the training split would be empty).
    """
    path_list = sorted(str(p) for p in Path(directory).rglob('*.jpg'))
    print(len(path_list))
    if val_size <= 0:
        raise ValueError(f"val_size must be positive, got {val_size}")
    if len(path_list) <= val_size:
        raise ValueError(
            f"found {len(path_list)} .jpg files under {directory!r}; "
            f"need more than val_size={val_size}"
        )
    split = len(path_list) - val_size
    train_ds = _make_image_dataset(path_list[:split])
    val_ds = _make_image_dataset(path_list[split:])
    return train_ds, val_ds
train_dataset, val_dataset = prepare_dataset()

# Plot up to the first nine images of one training batch in a 3x3 grid
# and save the figure to tmp/test.png.
plt.figure(figsize=(10, 10))
for batch in train_dataset.take(1):
    for idx in range(min(9, batch_size)):
        plt.subplot(3, 3, idx + 1)
        pixels = (batch[idx, :].numpy() * 255).astype("uint8")
        plt.imshow(pixels)
        plt.axis("off")
os.makedirs('tmp', exist_ok=True)
plt.savefig("tmp/test.png")
plt.close()
#
# REF
# https://gist.github.com/Lexie88rus/b6e66497a1b4b14aa01cc41e126a7c20
# https://www.kaggle.com/code/tachyon777/hubmap-tachyon-dataaugmentation-examples
#
import albumentations as A

# Photometric + geometric augmentation ending in a fixed 256x256 output.
# RandomBrightness / RandomContrast / Cutout are deprecated and removed in later
# albumentations releases; their replacements are used below.
augmentation_pipeline = A.Compose(
    [
        # Independent brightness and contrast jitter, each with p=0.75
        # (what RandomBrightness(limit=0.2) / RandomContrast(limit=0.2) did).
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.75),
        A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.75),
        # At most one blur/noise op, applied with p=0.7.
        A.OneOf([
            A.MotionBlur(blur_limit=5),
            A.MedianBlur(blur_limit=5),
            A.GaussianBlur(blur_limit=5),
            A.GaussNoise(var_limit=(5.0, 30.0)),
        ], p=0.7),
        # At most one non-rigid distortion, applied with p=0.7.
        A.OneOf([
            A.OpticalDistortion(distort_limit=1.0),
            A.GridDistortion(num_steps=5, distort_limit=1.),
            A.ElasticTransform(alpha=3),
        ], p=0.7),
        A.CLAHE(clip_limit=4.0, p=0.7),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.5),
        # border_mode=0 is cv2.BORDER_CONSTANT.
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=0.85),
        # Fix the output size so the dropout hole below is sized relative to 256x256.
        A.Resize(256, 256),
        # One 96x96 (0.375 * 256) hole with p=0.7. CoarseDropout replaces the
        # removed Cutout; hole placement may be sampled slightly differently.
        A.CoarseDropout(
            max_holes=1,
            max_height=int(256 * 0.375),
            max_width=int(256 * 0.375),
            p=0.7,
        ),
    ],
    p=1,
)
A clear PyTorch + Albumentations example (semantic segmentation):
https://albumentations.ai/docs/examples/pytorch_semantic_segmentation/
"elastic deformation"
https://github.com/gvtulder/elasticdeform
https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/augment/transforms.py#L138
References / more links:
https://arxiv.org/abs/1504.04003
https://gist.github.com/ernestum/601cdf56d2b424757de5
https://github.com/charlychiu/U-Net/blob/master/elastic_transform.py
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Alternative 0: custom loss function
https://stackoverflow.com/questions/64130293/custom-loss-function-in-keras-with-masking-array-as-input
Alternative 1: custom training step
https://stackoverflow.com/questions/64130293/custom-loss-function-in-keras-with-masking-array-as-input