Last active
December 22, 2023 17:59
-
-
Save ReemRashwan/8c92086d3104d01978a16e05ca93a165 to your computer and use it in GitHub Desktop.
Keras DICOM image data generator and augmenter for DataFrames (built on top of ImageDataGenerator).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
import pydicom | |
import cv2 | |
from sklearn.model_selection import train_test_split | |
from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
from keras_preprocessing.image.dataframe_iterator import DataFrameIterator | |
# tested on tf 2.1 | |
class DCMDataFrameIterator(DataFrameIterator):
    """DataFrameIterator that loads DICOM (.dcm) files with pydicom.

    Accepts the same arguments as keras_preprocessing's DataFrameIterator.
    The parent still does the bookkeeping (filename validation, class
    indexing, `subset` slicing, shuffling); this class replaces how a batch
    is loaded, augmented and standardized.

    Attributes:
        dataframe: the dataframe passed in by the caller (all of its rows).
        x: pandas Series with the file path of every sample this iterator
            serves, i.e. only the rows the parent kept.
        y: pandas Series with the matching labels, as they appear in y_col.

    Raises:
        ValueError: for class modes without per-sample labels
            (e.g. None, 'input', 'multi_output').
    """

    # Read by the parent's filename validation: only paths ending in "dcm"
    # are kept. The trailing comma matters -- ('dcm') is just a string.
    white_list_formats = ('dcm',)

    def __init__(self, *arg, **kwargs):
        super().__init__(*arg, **kwargs)
        # `dataframe` is the parent's first positional parameter.
        self.dataframe = kwargs['dataframe'] if 'dataframe' in kwargs else arg[0]
        # Batch indices are positions among the rows the parent kept (after
        # dropping invalid files and applying `subset`), so x/y must come from
        # that filtered view. Indexing the raw dataframe made the 'training'
        # subset read the rows reserved for 'validation'.
        # color_mode and target_size are already set and validated by the parent.
        self.x = pd.Series(self.filepaths)
        self.y = pd.Series(self._kept_labels())

    def _kept_labels(self):
        """Return the y_col value of every kept sample, in the parent's order."""
        if self.class_mode == 'raw':
            # 'raw' keeps the y_col values untouched.
            return list(self.labels)
        if not hasattr(self, 'class_indices'):
            raise ValueError(
                'DCMDataFrameIterator needs per-sample labels; '
                'unsupported class_mode: %r' % (self.class_mode,))
        # Other labelled modes store class indices; map them back to the
        # original label values so batch_y matches casting y_col directly.
        index_to_label = {index: label for label, index in self.class_indices.items()}
        return [index_to_label[index] for index in self.classes]

    def _get_batches_of_transformed_samples(self, indices_array):
        """Load, augment and standardize the samples at `indices_array`.

        Args:
            indices_array: positions (into self.x / self.y) of the batch samples.

        Returns:
            (batch_x, batch_y): stacked images cast to self.dtype (the
            parent's dtype argument, a float type by default), and the labels
            cast to uint8.
        """
        batch_x = np.array(
            [self.read_dcm_as_array(dcm_path, self.target_size, color_mode=self.color_mode)
             for dcm_path in self.x.iloc[indices_array]],
            # Cast up front so augmentation/rescaling results are not truncated
            # to the integer pixel dtype of the DICOM data.
            dtype=self.dtype)
        # astype because the labels are stored as strings (y was passed as str)
        batch_y = np.array(self.y.iloc[indices_array].astype(np.uint8))
        if self.image_data_generator is not None:
            for i, x in enumerate(batch_x):
                transform_params = self.image_data_generator.get_random_transform(x.shape)
                x = self.image_data_generator.apply_transform(x, transform_params)
                # standardize() applies rescale / preprocessing_function /
                # centering; without it those settings are silently ignored.
                batch_x[i] = self.image_data_generator.standardize(x)
                # You can change y here as well, e.g. in semantic segmentation,
                # transform the masks with the same transform_params.
        return batch_x, batch_y

    @staticmethod
    def read_dcm_as_array(dcm_path, target_size=(256, 256), color_mode='rgb'):
        """Read a DICOM file into a (height, width, channels) array.

        Assumes a single-frame image whose pixel_array is 2-D (grayscale) --
        confirm for multi-frame or colour DICOMs.

        Args:
            dcm_path: path to the DICOM file.
            target_size: (height, width), the Keras target_size order.
            color_mode: 'rgb' repeats a single gray channel three times;
                any other value leaves the channel count unchanged.

        Returns:
            The resized pixel data, keeping the dtype pydicom returned.
        """
        image_array = pydicom.dcmread(dcm_path).pixel_array
        height, width = target_size
        # cv2.resize wants dsize as (width, height), the reverse of Keras.
        image_array = cv2.resize(image_array, (width, height),
                                 interpolation=cv2.INTER_NEAREST)
        if image_array.ndim == 2:
            image_array = np.expand_dims(image_array, -1)
        if color_mode == 'rgb' and image_array.shape[-1] == 1:
            # np.repeat works for any dtype; cv2.cvtColor(GRAY2RGB) only accepts
            # 8-bit unsigned, 16-bit unsigned and float32 input (not e.g. int16).
            image_array = np.repeat(image_array, 3, axis=-1)
        return image_array
# Read the data. The CSV is assumed to have two columns:
#   image_path: path to each DICOM file, including its extension
#   target:     labels (here '0' and '1') -> binary classification
# dtype=str keeps every column, labels included, as strings.
df = pd.read_csv("yourDfPath.csv", dtype=str)

# Hold out 20% of the rows for testing. random_state pins the split so the
# test set is the same on every run (1337 is the same value as SEED).
train_df, test_df = train_test_split(df, test_size=0.2, random_state=1337)
# ImageDataGenerator settings for each split.
# Only the training split gets random augmentation; validation and test only
# set rescale. If you fine-tune a pretrained network, you can pass its
# preprocessing_function instead of rescale in all three dicts.
# validation_split appears in both the train and valid settings because both
# iterators take their subset from the same dataframe.
train_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
    'rotation_range': 10,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
    'brightness_range': [0.8, 1.2],
    'validation_split': 0.2,
}
valid_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
    'validation_split': 0.2,
}
test_augmentation_parameters = {
    'rescale': 1.0 / 255.0,
}
# Training parameters.
BATCH_SIZE = 32
CLASS_MODE = 'binary'
COLOR_MODE = 'grayscale'
TARGET_SIZE = (300, 300)  # square, so the (height, width) order does not matter
EPOCHS = 10
SEED = 1337

# Iterator keyword arguments shared by every split.
_common_consts = {
    'class_mode': CLASS_MODE,
    'color_mode': COLOR_MODE,
    'target_size': TARGET_SIZE,  # images are resized to this
}

# Train and validation use the same seed; 'subset' picks which share of the
# dataframe each iterator reads.
train_consts = {'seed': SEED, 'batch_size': BATCH_SIZE, **_common_consts, 'subset': 'training'}
valid_consts = {'seed': SEED, 'batch_size': BATCH_SIZE, **_common_consts, 'subset': 'validation'}

# Test time: one sample per batch, and no shuffling, so batches follow the
# order of the test dataframe.
test_consts = {'batch_size': 1, **_common_consts, 'shuffle': False}
# Training-phase generators. Both read train_df; the 'subset' entry in
# train_consts / valid_consts tells each iterator which share of it to use.
train_augmenter = ImageDataGenerator(**train_augmentation_parameters)
valid_augmenter = ImageDataGenerator(**valid_augmentation_parameters)

_train_df_columns = dict(dataframe=train_df, x_col='image_path', y_col='target')
train_generator = DCMDataFrameIterator(image_data_generator=train_augmenter,
                                       **_train_df_columns, **train_consts)
valid_generator = DCMDataFrameIterator(image_data_generator=valid_augmenter,
                                       **_train_df_columns, **valid_consts)
# Build and compile the model architecture as you normally would.
model = ...

# Training. Model.fit accepts generators directly; fit_generator has been
# deprecated since TF 2.1.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
)
# After training, evaluate on the held-out test split. The test iterator uses
# test_augmentation_parameters, which contain no random augmentation.
test_augmenter = ImageDataGenerator(**test_augmentation_parameters)
test_generator = DCMDataFrameIterator(
    dataframe=test_df,
    x_col='image_path',
    y_col='target',
    image_data_generator=test_augmenter,
    **test_consts
)
test_steps = len(test_generator)
# Unpacking into two values assumes the model was compiled with exactly one
# metric (presumably accuracy) -- adjust if you add more.
test_loss, test_accuracy = model.evaluate(test_generator, steps=test_steps)
Both keras_preprocessing and tensorflow.keras.preprocessing are deprecated.
Thanks for the note; this may serve as a reference for future contributors.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Both keras_preprocessing and tensorflow.keras.preprocessing are deprecated.