Last active
          July 19, 2021 14:20 
        
      - 
      
- 
        Save bmabir17/f37bc164c3f45c1c1b6ec3db0aeb231d to your computer and use it in GitHub Desktop. 
    Transfer Learning Models for Training Classifier
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import keras | |
| from tensorflow.keras.utils import Sequence | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from random import shuffle | |
| from tensorflow.keras.utils import to_categorical | |
# Configure the root logger once at import time (INFO level, timestamped
# records) and expose a module-level logger named after this module.
logging.basicConfig(
    format='%(asctime)s %(name)s %(levelname)s:%(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class SmartDataGenerator(Sequence):
    """Keras ``Sequence`` that loads image batches from disk, optionally oversampling minority classes.

    Minority classes are padded with *synthetic* samples: each synthetic entry
    is the path of a real image with an extra ``__aug__`` component appended.
    When such an entry is loaded, the real image is read and passed through
    ``balancing_augmentation``.

    Args:
        data: Tuple ``(list_IDs, labels)``. ``list_IDs`` is a list of image
            paths; ``labels`` maps ``str(path)`` to an integer class id.
            Both are copied, so balancing never mutates the caller's objects.
        batch_size: Number of samples per batch.
        target_resolution: ``(rows, cols)`` of the produced images.
        n_channels: Number of channels of the produced images.
        n_classes: Number of classes used for one-hot encoding.
        balancing_augmentation: Callable invoked as ``aug(image=img)["image"]``
            to create synthetic samples.
        class_balancing: Whether to oversample minority classes.
        shuffle: Whether to reshuffle the sample order after every epoch.
        resize_technique: PIL resampling filter used when resizing.
    """

    # Path component appended to a real file path to flag a synthetic sample.
    AUG_FLAG = '__aug__'
    # Imbalance (in percent) tolerated for the dominant class;
    # 0 would mean every class is padded up to the size of the largest one.
    DOMINATE_PCT = 4
    # PIL mode matching each supported channel count.
    _PIL_MODES = {1: 'L', 3: 'RGB', 4: 'RGBA'}

    def __init__(self, data, batch_size, target_resolution, n_channels, n_classes, balancing_augmentation,
                 class_balancing=True, shuffle=True, resize_technique=Image.BICUBIC):
        """Store the configuration, balance the classes and build the first index order."""
        super().__init__()
        self.target_resolution = target_resolution
        self.n_channels = n_channels
        self.batch_size = batch_size
        list_ids, labels = data
        # Copy: balance_class() appends synthetic entries, and callers commonly
        # share one label dict between several generators.
        self.list_IDs = list(list_ids)
        self.labels = dict(labels)
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.resize_technique = resize_technique
        self.class_balancing = class_balancing
        self.b_augment = balancing_augmentation
        self.balance_class()
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, pivot_index):
        """Generate the batch number ``pivot_index`` as ``(X, y_one_hot)``."""
        start = pivot_index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        batch_ids = [self.list_IDs[k] for k in batch_indexes]
        return self.__data_generation(batch_ids)

    def read_img(self, path):
        """Read an image, force the configured channel layout, resize it and return a numpy array."""
        with Image.open(path) as raw_img:
            img_obj = raw_img
            # Make grayscale/RGBA/palette files match n_channels so the later
            # reshape cannot fail on a stray image of a different mode.
            wanted_mode = self._PIL_MODES.get(self.n_channels)
            if wanted_mode is not None and img_obj.mode != wanted_mode:
                img_obj = img_obj.convert(wanted_mode)
            # PIL expects (width, height) while target_resolution is (rows, cols).
            rows, cols = self.target_resolution
            img_obj = img_obj.resize((cols, rows), self.resize_technique)
            return np.array(img_obj)

    def __data_generation(self, list_IDs_temp):
        """Load the given sample ids from disk and return ``(X, y_one_hot)``.

        Raises:
            FileNotFoundError: If an id is neither an existing file nor a
                synthetic ``__aug__`` entry.
        """
        X = np.empty((self.batch_size, *self.target_resolution, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)
        for i, img_ID in enumerate(list_IDs_temp):
            sample_path = Path(img_ID)
            if sample_path.exists():
                img = self.read_img(sample_path)
            elif sample_path.name == self.AUG_FLAG:
                # Synthetic sample: load the real parent image and augment it.
                img = self.read_img(sample_path.parent)
                img = self.b_augment(image=img)["image"]
            else:
                raise FileNotFoundError(
                    'File not found at {}th iteration. Please check {}'.format(i, img_ID))
            img = img.reshape((img.shape[0], img.shape[1], self.n_channels))
            X[i,] = img.astype('float32') / 255  # scale pixel values to [0, 1]
            y[i] = self.labels[str(img_ID)]
        return X, to_categorical(y, num_classes=self.n_classes)

    def balance_class(self):
        """Pad minority classes with synthetic ``__aug__`` entries.

        Every class is padded up to ``(100 - DOMINATE_PCT)%`` of the size of
        the largest class. New entries are appended to ``self.list_IDs`` and
        registered in ``self.labels``.
        """
        if not self.class_balancing:
            logging.info('Skipping data balancing')
            return
        logging.info('-------------------\n\nStarting Datasets balancing using SMOT...\n')

        # First track all file names per class.
        files_by_class = defaultdict(list)
        for path in self.list_IDs:
            files_by_class[self.labels[str(path)]].append(Path(path))
        if not files_by_class:
            # Nothing to balance; also avoids dividing by zero below.
            logging.info('No data found, nothing to balance')
            return

        # Work out how many synthetic files each class needs.
        n_found_classes = len(files_by_class)
        total = sum(len(files) for files in files_by_class.values())
        max_cls_file_num = max(len(files) for files in files_by_class.values())
        new_estimated_total = int(
            n_found_classes * (max_cls_file_num - (max_cls_file_num * self.DOMINATE_PCT / 100.0)))
        per_class_target = new_estimated_total / n_found_classes
        logging.info('Found {} data. More {} synthetic data will be generated (new total = {}).\t{}/class\n'
                     .format(total, new_estimated_total - total, new_estimated_total, per_class_target))

        for cls, cls_files in files_by_class.items():
            current_count = len(cls_files)
            # Classes already within the tolerated imbalance need nothing.
            more_needed = max(0, int(per_class_target - current_count))
            if more_needed == 0:
                logging.info(
                    'class-{} need not any synthetic data. Already it has {}'.format(cls, current_count))
                continue
            logging.info('class-{} will need more {}.\tcurrently have {}'.format(cls, more_needed, current_count))
            for i in range(more_needed):
                # Cycle through the real files of the class.
                source_file = cls_files[i % current_count]
                synthetic_id = source_file / self.AUG_FLAG
                self.list_IDs.append(synthetic_id)
                self.labels[str(synthetic_id)] = cls

    def on_epoch_end(self):
        """Rebuild the sample order after each epoch, shuffling it when enabled."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)
class TukiTaki:
    """Namespace for small prediction-time helpers."""

    @staticmethod
    def get_class_name_from_hotencoding(hot_encoding):
        """Return the class name for a one-hot (or score) vector.

        The index of the largest entry along axis 0 is looked up in
        ``Reverse_class_index_mapping.txt`` (a JSON object keyed by the class
        index as a string), read from the current working directory.

        Declared as a static method so it can be called on the class and on
        instances alike.

        Raises:
            FileNotFoundError: If the mapping file does not exist.
            KeyError: If the index is not present in the mapping.
        """
        import json

        class_index = np.argmax(hot_encoding, axis=0)
        with open('Reverse_class_index_mapping.txt', encoding='utf-8') as f:
            index_to_name = json.load(f)
        return index_to_name[str(class_index)]
class KfoldMaker:
    """Scan an image dataset laid out as ``<dataset_dir>/.../<class_name>/<image>`` and split it into K folds.

    Attributes:
        all_path_list: List of ``Path`` objects of every image found.
        all_labels_dic: Dict mapping ``str(path)`` to the integer class id.
    """

    def __init__(self, dataset_dir, image_extensions):
        """Remember the configuration and scan the dataset immediately.

        Args:
            dataset_dir: Root directory searched recursively.
            image_extensions: Iterable of extensions without the leading dot,
                e.g. ``['jpeg', 'png']``.
        """
        self.dataset_dir = dataset_dir
        self.image_extensions = image_extensions
        self.all_path_list, self.all_labels_dic = self.scan_dataset()

    def generate_folds(self, K):
        """Shuffle the file list in place and return K train/validation splits.

        Returns:
            A list of K items shaped
            ``((train_paths, train_labels), (val_paths, val_labels))``.
            Each validation part holds ``len(all_path_list) // K`` files, so
            leftover files only ever appear in the training parts.

        Raises:
            ValueError: If ``K`` is smaller than 1.
        """
        if K < 1:
            raise ValueError('K must be a positive integer, got {}'.format(K))
        np.random.shuffle(self.all_path_list)
        num_validation_samples = len(self.all_path_list) // K
        fold_list = []
        logging.info('\n\nDividing dataset into {} folds...'.format(K))
        for fold in range(K):
            val_start = num_validation_samples * fold
            val_end = val_start + num_validation_samples
            validation_data_x = self.all_path_list[val_start:val_end]
            training_data_x = self.all_path_list[:val_start] + self.all_path_list[val_end:]
            validation_data_y = {str(x): self.all_labels_dic[str(x)] for x in validation_data_x}
            training_data_y = {str(x): self.all_labels_dic[str(x)] for x in training_data_x}
            fold_list.append(((training_data_x, training_data_y), (validation_data_x, validation_data_y)))
            # Tiny textual progress bar: '#' marks the fold just generated.
            status_bar = ['_'] * K
            status_bar[fold] = '#'
            logging.info('Generated K-fold metadata for fold: {} {}'.format(str(fold), ''.join(status_bar)))
        return fold_list

    def scan_dataset(self):
        """Collect all images and derive their labels from the parent directory name.

        Class ids are assigned in order of first appearance. The id -> name
        mapping is also written as JSON to ``Reverse_class_index_mapping.txt``
        in the current working directory, for use at prediction time.

        Returns:
            Tuple ``(images_path_list, label_dic)``.
        """
        import json
        from pathlib import Path

        root_path = Path(self.dataset_dir)
        images_path_list = []
        for extension in self.image_extensions:
            img_list = list(root_path.rglob('*.' + extension))
            images_path_list.extend(img_list)
            logging.debug('{} images found with {} extension'.format(len(img_list), extension))
        logging.info('Found {} images with {} extensions\n'.format(len(images_path_list), self.image_extensions))

        label_dic = {}
        cls_id_map = {}
        reverse_cls_id_map = {}
        for path in images_path_list:
            cls_name = path.parent.name
            if cls_name not in cls_id_map:
                new_id = len(cls_id_map)
                cls_id_map[cls_name] = new_id
                reverse_cls_id_map[new_id] = cls_name
            # Use the id of THIS image's class, not the most recently created
            # id: files are not guaranteed to arrive grouped by class (they
            # are listed extension by extension).
            label_dic[str(path)] = cls_id_map[cls_name]

        # Save the mapping between assigned index and original class name;
        # it is needed to decode predictions later.
        with open('Reverse_class_index_mapping.txt', 'w') as file:
            file.write(json.dumps(reverse_cls_id_map))
        logging.info('Detected {} correctly labeled images inside {}'.format(len(images_path_list), self.dataset_dir))
        logging.info('Total {} class found. {}\n'.format(len(cls_id_map), cls_id_map))
        return images_path_list, label_dic
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def build_mobilenetv2_model(config): | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.applications import MobileNetV2 | |
| from tensorflow.keras.optimizers import Adam | |
| input_shape = (config.input_size, config.input_size, config.n_channels) | |
| # create MobileNetv2 model with ImageNet weights and with no prediction layer | |
| base_model=MobileNetV2(weights='imagenet',input_shape=input_shape,include_top=False) | |
| base_model.trainable=False | |
| base_model.summary() | |
| prediction_layer = tf.keras.layers.Dense(config.n_classes, activation= 'softmax') | |
| # Add prediction layers | |
| inputs = tf.keras.Input(shape=input_shape) | |
| x = base_model(inputs, training=False) | |
| x = layers.GlobalMaxPooling2D()(x) | |
| x = layers.Dense(1024,activation='relu')(x) | |
| x = layers.Dense(1024,activation='relu')(x) | |
| x = layers.Dense(512,activation='relu')(x) | |
| x = layers.Dense(256,activation='relu')(x) | |
| x = layers.Dense(128,activation='relu')(x) | |
| outputs = prediction_layer(x) | |
| model = tf.keras.Model(inputs, outputs) | |
| adam = Adam(lr=config.learning_rate) | |
| model.compile(optimizer= adam, loss='categorical_crossentropy', metrics=[ | |
| tf.keras.metrics.BinaryAccuracy(name='accuracy'), | |
| tf.keras.metrics.TruePositives(name='true_pos'), | |
| tf.keras.metrics.FalsePositives(name='false_pos'), | |
| tf.keras.metrics.TrueNegatives(name='true_neg'), | |
| tf.keras.metrics.FalseNegatives(name='false_neg'), | |
| tf.keras.metrics.Precision(name='precision'), | |
| tf.keras.metrics.Recall(name='recall'), | |
| tf.keras.metrics.AUC(name='auc'), | |
| tf.keras.metrics.AUC(name='prc', curve='PR')]) | |
| model.summary() | |
| return model | |
def build_resnet_model(config):
    """Build and compile a transfer-learning classifier on a frozen ResNet50.

    The ImageNet-pretrained backbone (without its top) is frozen and followed
    by global max pooling, three ReLU dense layers (1024, 1024, 512) and a
    softmax layer with ``config.n_classes`` units.

    Args:
        config: Object exposing ``input_size``, ``n_channels``, ``n_classes``
            and ``learning_rate`` attributes.

    Returns:
        The compiled Keras ``Model``.
    """
    import tensorflow as tf
    from tensorflow.keras import layers
    from tensorflow.keras.models import Model
    from tensorflow.keras.applications.resnet50 import ResNet50
    from tensorflow.keras.optimizers import Adam

    input_shape = (config.input_size, config.input_size, config.n_channels)

    # ResNet50 with ImageNet weights and without its prediction layer.
    base_model = ResNet50(weights='imagenet', input_shape=input_shape, include_top=False)
    base_model.trainable = False

    # Classification head on top of the frozen backbone.
    x = layers.GlobalMaxPooling2D()(base_model.output)
    for units in (1024, 1024, 512):
        x = layers.Dense(units, activation='relu')(x)
    x = layers.Dense(config.n_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=x)

    # `learning_rate` is the supported keyword; `lr` is deprecated and
    # rejected by recent Keras releases.
    adam = Adam(learning_rate=config.learning_rate)
    model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=[
        # Categorical (arg-max) accuracy matches the softmax/one-hot setup;
        # element-wise binary accuracy overstates it for more than two classes.
        tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf.keras.metrics.TruePositives(name='true_pos'),
        tf.keras.metrics.FalsePositives(name='false_pos'),
        tf.keras.metrics.TrueNegatives(name='true_neg'),
        tf.keras.metrics.FalseNegatives(name='false_neg'),
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.AUC(name='prc', curve='PR')])
    return model
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | from DatasetTools import SmartDataGenerator, KfoldMaker | |
| from PIL import Image | |
| from albumentations import ( | |
| Compose, HorizontalFlip, | |
| RandomBrightness, RandomContrast, | |
| ) | |
| from model import build_simple_model,build_resnet_model,build_resnet_model2,build_mobilenetv2_model,build_mobilenetv2_model2 | |
| import tensorflow as tf | |
| from tensorflow.python.keras import backend as K | |
| from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint | |
| from tensorflow.keras.models import load_model | |
| import wandb | |
| from wandb.keras import WandbCallback | |
| # 1. Start a new run | |
| wandb.init(project='projectName', entity='userName') | |
| 2. Save model inputs and hyperparameters | |
| config = wandb.config | |
| config.epochs = 30 | |
| config.num_validation=100 | |
| config.batch_size=16 | |
| config.n_channels=3 | |
| config.n_classes=2 | |
| # config.k_fold=5 | |
| config.cleaned=True | |
| config.learning_rate=1e-06 | |
| config.model_type='mobilenet' # 'baseModel', 'resnet', 'mobilenet' | |
| if(config.model_type=='mobilenet'): | |
| config.input_size=224 #For mobilenet | |
| elif(config.model_type=='baseModel'): | |
| config.input_size=256 #For basic_model | |
| config.dataset_dir = 'data/train_clean1' | |
| config.model_save_path=f"saved_model/{config.model_type}/e{config.epochs}_es_s{config.input_size}_cleanData" | |
| # ---------------------------------------------------------------------------------------------------------------------- | |
| # ---------------------------------------------------------------------------------------------------------------------- | |
| # Parameters | |
| params = {'batch_size': config.batch_size, | |
| 'target_resolution': (config.input_size, config.input_size), | |
| 'n_channels': config.n_channels, | |
| 'n_classes': config.n_classes, | |
| 'balancing_augmentation': Compose([RandomContrast(limit=0.2, p=0.5), RandomBrightness(limit=0.2, p=0.5)] ), | |
| # 'balancing_augmentation': Compose([] ), | |
| 'shuffle': True, | |
| 'resize_technique': Image.BICUBIC} | |
| # NOTE: here KfoldMaker is only used to scan files and generate file list | |
| k_obj = KfoldMaker(dataset_dir=config.dataset_dir, image_extensions=['jpeg']) | |
| num_validation = config.num_validation # number of validation set you want | |
| X = k_obj.all_path_list | |
| Y = k_obj.all_labels_dic | |
| # print(Y) | |
| train_x = X[ : -num_validation] | |
| val_x = X[ num_validation : ] | |
| training_gen = SmartDataGenerator(data=(train_x, Y), **params) | |
| validation_gen = SmartDataGenerator(data=(val_x, Y), **params) | |
| # simple early stopping | |
| es = EarlyStopping(monitor='val_loss', mode='min', verbose=1,patience=5) | |
| mc = ModelCheckpoint('best_model.h5', monitor='val_accuracy', mode='max', verbose=1, save_best_only=True) | |
| if(config.model_type=='mobilenet'): | |
| model = build_mobilenetv2_model2(config) | |
| elif(config.model_type=='baseModel'): | |
| model = build_simple_model(config) | |
| log = model.fit(x=training_gen, | |
| validation_data=validation_gen, | |
| workers=9, # realtime loading with parallel processing | |
| epochs=config.epochs, verbose=1, | |
| callbacks=[ | |
| es, | |
| mc, | |
| WandbCallback( | |
| save_model=True | |
| )]) | |
| # load the Best saved model | |
| saved_model = load_model('best_model.h5') | |
| saved_model.save(config.model_save_path) | |
| model_json = saved_model.to_json() | |
| with open(f"{config.model_save_path}/arch.json", "w") as json_file: | |
| json_file.write(model_json) | |
| saved_model.save_weights(f"{config.model_save_path}/weights.h5") | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment