Last active: December 21, 2018 15:48
Save ravnoor/771f14996f77c49c3a11d38886fb6a0d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/andreimouraviev/Mets/blob/a8ce43b335584187f0728a4b3975c679ccf06cbc/UNET_2D_AUG17.py | |
# coding: utf-8 | |
# In[1]: | |
import os
# Presumably moves to the project root so the sibling UTILS module is importable
# (this is a notebook export, where the cwd is on sys.path).
os.chdir(r'/home/amourav/Python')
import sklearn
# Star import provides K, keras, np, time, the metric functions and the data helpers used below.
from UTILS import *
# Channels-last tensors: (batch, rows, cols, channels), matching Input((rows, cols, 1)) in get_Unet.
K.set_image_dim_ordering('tf')
K.set_image_data_format('channels_last')
# Restrict TensorFlow to GPU 6.
# NOTE(review): this runs after UTILS has imported TensorFlow; it only takes effect if CUDA
# has not been initialised yet -- confirm, or move it above the UTILS import.
os.environ["CUDA_VISIBLE_DEVICES"]="6"
#set_gpu_limit(0.5)
# In[2]: | |
def _conv_pair(x, filters, l2_reg):
    """Apply two 3x3 'same'-padded ReLU convolutions with `filters` channels."""
    x = Conv2D(filters, (3, 3), activation='relu', padding='same', kernel_regularizer=l2_reg)(x)
    return Conv2D(filters, (3, 3), activation='relu', padding='same', kernel_regularizer=l2_reg)(x)


def get_Unet(img_rows, img_cols, LossF, Metrics, Optimizer=None, DropP=0, reg=0.000, batch_norm=False):
    """Build and compile a 4-level 2D U-Net with a single sigmoid output channel.

    Args:
        img_rows, img_cols: spatial size of the single-channel input.
        LossF: Keras loss passed to ``model.compile``.
        Metrics: list of Keras metrics passed to ``model.compile``.
        Optimizer: Keras optimizer. Defaults to a fresh ``Adam(1e-5)`` per call.
        DropP: dropout rate applied after every pooling and every skip concatenation.
        reg: L2 kernel-regularisation coefficient for all conv/transposed-conv layers.
        batch_norm: if True, batch-normalise the input before the first convolution.

    Returns:
        The compiled ``keras.models.Model``.
    """
    if Optimizer is None:
        # A default of Adam(1e-5) in the signature would be built once at import
        # time and shared by every model created with the default.
        Optimizer = Adam(1e-5)
    l2_reg = keras.regularizers.l2(reg)
    print('Opti {0}, DropP {1}, Loss {2}, reg {3}'.format(Optimizer, DropP, LossF, reg))
    inputs = Input((img_rows, img_cols, 1))
    # BatchNormalization must be instantiated, then called on the tensor. `inputs`
    # itself must stay the Input tensor so Model(inputs=inputs) is valid.
    x = BatchNormalization()(inputs) if batch_norm else inputs

    # Contracting path: 32 -> 64 -> 128 -> 256 filters, keeping each level for the skips.
    skips = []
    for filters in (32, 64, 128, 256):
        x = _conv_pair(x, filters, l2_reg)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(DropP)(x)

    # Bottleneck.
    x = _conv_pair(x, 512, l2_reg)

    # Expanding path: upsample, concatenate with the matching skip, then convolve.
    for filters, skip, name in zip((256, 128, 64, 32), reversed(skips), ('up6', 'up7', 'up8', 'up9')):
        up = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same', kernel_regularizer=l2_reg)(x)
        x = concatenate([up, skip], name=name, axis=3)
        x = Dropout(DropP)(x)
        x = _conv_pair(x, filters, l2_reg)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=Optimizer, loss=LossF, metrics=Metrics)
    return model
# In[3]: | |
# CALLBACKS | |
class Print_Loss(keras.callbacks.Callback):
    """Print a one-line loss summary at the end of every epoch.

    train_UNET adds it when Keras' own progress output is off (Verbose=0).
    """

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        # Avoid a shared mutable default; Keras normally passes a dict.
        logs = logs or {}
        print('-epoch {0} loss {1} val_loss {2}'.format(epoch, logs.get('loss'), logs.get('val_loss')))
def train_UNET(Prep_func, data, target, patient_id, Batch_Norm, Data_Path_out, Optimizer, CALLBACK, Nepochs, Batch_size, Nsplits, lossF, METRICS, Verbose=0, DropP=0, S=1, ClassW=None, shw_plt=0, splitN=None, SkipTo=None, Nepochs_Pre=0, Sample_Weights=None, Reg=0.000, TESTMODE=False):
    """Train one U-Net per patient-grouped train/validation split.

    For each split from get_splits_idxs the function does the following:
    - resumes from ``unet_<i>.hdf5`` if it exists, otherwise builds a model with get_Unet;
    - preprocesses the data with Prep_func and shuffles both sets;
    - trains with CALLBACK plus a best-val_loss checkpoint and a Record_LOSS_METRICS logger;
    - evaluates the model.
    The working directory is changed to Data_Path_out.

    Args:
        Prep_func: callable(train_x, train_y, val_x, val_y) returning the four arrays preprocessed.
        CALLBACK: extra Keras callbacks. The list itself is not modified.
        splitN: index of a single split to skip.
        SkipTo: if not None, skip every split whose index is below it.
        S: if truthy, call save_model(model, i) after training.
        Nepochs_Pre: accepted for compatibility; pretraining is disabled.
        TESTMODE: smoke test on the first 100 slices, 1 split, 2 epochs.

    Returns:
        List with one ``[scores, model, history_dict]`` entry per trained split.
    """
    if TESTMODE:
        data, target, patient_id = data[0:100], target[0:100], patient_id[0:100]
        Nsplits, Nepochs = 1, 2
    os.chdir(Data_Path_out)
    # Take the input size from the data rather than relying on module-level globals.
    img_rows, img_cols = data.shape[1], data.shape[2]
    Model_List = []
    VAL_PATIENTS = []
    for i, TRAIN_TEST_SPLIT in enumerate(get_splits_idxs(data, patient_id, Nsplits)):
        Train_idx = TRAIN_TEST_SPLIT[0]
        print('shuffle {0} \n'.format(i))
        # SkipTo=None means "skip nothing"; comparing `i >= None` only worked on Python 2.
        if splitN == i or (SkipTo is not None and i < SkipTo):
            print('skipping {0} \n'.format(i))
            continue
        data_train, target_train, patient_train, data_valid, target_valid, patient_valid = split_data(data, target, patient_id, TRAIN_TEST_SPLIT)
        Sample_Weights_trn = None if Sample_Weights is None else Sample_Weights[Train_idx]

        checkpoint_file = 'unet_{0}.hdf5'.format(i)
        if os.path.isfile(checkpoint_file):
            model = Load_Model(checkpoint_file, {'Tversky': Tversky, 'W_Bin_C_Entropy': W_Bin_C_Entropy})
            print('loading model..')
        else:
            model = get_Unet(img_rows, img_cols, lossF, METRICS, Optimizer, DropP, reg=Reg, batch_norm=Batch_Norm)

        # Prep data for the U-Net, then shuffle each set.
        data_train_prep, target_train_prep, data_valid_prep, target_valid_prep = Prep_func(data_train, target_train, data_valid, target_valid)
        print_info(data_train_prep, target_train_prep, 'train')
        print_info(data_valid_prep, target_valid_prep, 'valid')
        shuffled_imgs, shuffled_masks = shuffle(data_train_prep, target_train_prep)
        shuffled_imgs_validation, shuffled_masks_validation = shuffle(data_valid_prep, target_valid_prep)
        if shw_plt and i == 1:
            sample_plot(shuffled_imgs, shuffled_masks, n=5, Figs=10)

        # Build a per-split list. Appending to CALLBACK itself added one more
        # Print_Loss on every split (duplicated output).
        callbacks = list(CALLBACK)
        if not Verbose:
            callbacks.append(Print_Loss())
        model_checkpoint = ModelCheckpoint(checkpoint_file, monitor='val_loss', save_best_only=True)
        log_loss = Record_LOSS_METRICS()
        log_loss.set_split(i)
        callbacks += [model_checkpoint, log_loss]

        t1 = time.time()
        print('{0}\nFitting model...\n{0}'.format('-' * 30))
        history_callback = model.fit(shuffled_imgs, shuffled_masks, batch_size=Batch_size, epochs=Nepochs,
                                     verbose=Verbose, shuffle=True, callbacks=callbacks, class_weight=ClassW,
                                     validation_data=(shuffled_imgs_validation, shuffled_masks_validation),
                                     sample_weight=Sample_Weights_trn)
        t2 = time.time()
        if S:
            save_model(model, i)
        print('----------------\ntrain time {0} h\n----------------'.format((t2 - t1) / 3600))
        print('Patients: {0} \r\n'.format(np.unique(patient_valid)))
        VAL_PATIENTS.append(np.unique(patient_valid))

        if shw_plt:
            plot_loss(history_callback.history)
            # Validation predictions are only needed for the plot.
            img_msk_prediction = model.predict(shuffled_imgs_validation, verbose=0, batch_size=1)
            sample_results_(shuffled_imgs_validation, shuffled_masks_validation, img_msk_prediction, n=15, Figsize=(20, 40))

        # NOTE(review): this evaluates on the (shuffled) training slices, not the
        # validation set -- confirm that is intended.
        scores = model.evaluate(shuffled_imgs, shuffled_masks, batch_size=1, verbose=0)
        for score, metric in zip(scores, model.metrics_names):
            print('{0} score: {1}'.format(metric, score))
        # Fresh record per split. The original reused one list and deleted it
        # after the first split, so the second split raised NameError.
        Model_List.append([scores, model, history_callback.history])
        print('\n ---------------- \n ----------------')
    return Model_List
class Record_LOSS_METRICS(keras.callbacks.Callback):
    """Log per-epoch loss and metrics for one CV split to text files in the cwd.

    Writes ``UNET_loss_<split>.txt`` and ``UNET_metrics_<split>.txt``.
    ``set_split`` must be called before training starts.
    """

    def set_split(self, i_split):
        """Set the split index used in the log file names."""
        self.i_split = i_split

    def on_train_begin(self, logs=None):
        # Create (or truncate) both log files with a header line.
        with open('UNET_loss_{0}.txt'.format(self.i_split), 'w') as fid:
            fid.write('RECORDING LOSS: \r\n')
        with open('UNET_metrics_{0}.txt'.format(self.i_split), 'w') as fid:
            fid.write('RECORDING METRICS: \r\n')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        with open('UNET_loss_{0}.txt'.format(self.i_split), 'a') as fid:
            fid.write('{0} train {1} val {2} \r\n'.format(epoch, logs.get('loss'), logs.get('val_loss')))
        # One placeholder per value. The old string reused {1} for FPR, which
        # shifted every later column. "Dice" holds dice_coef_loss, i.e. the negated Dice.
        with open('UNET_metrics_{0}.txt'.format(self.i_split), 'a') as fid:
            fid.write('{0} FNR {1} FPR {2} TNR {3} TPR {4} Precision {5} Dice {6} \r\n'.format(
                epoch, logs.get('FNR'), logs.get('FPR'), logs.get('TNR'), logs.get('TPR'),
                logs.get('Precision'), logs.get('dice_coef_loss')))
def get_class_weights(target):
    """Return float32 ``[background, foreground]`` weights for a binary mask array.

    The mask is scaled by its maximum. Each class is weighted by the fourth root
    of the *other* class's pixel count, so the rarer class gets the larger weight.
    The pair is rescaled to sum to 2.
    """
    scaled = target / target.max()
    n_background = np.sum(scaled == 0)
    n_foreground = np.sum(scaled == 1)
    weights = np.array([n_foreground, n_background]) ** 0.25
    weights = weights / np.sum(weights) * 2.
    return weights.astype(np.float32)
# In[4]: | |
# Directory of preprocessed 2D slices (folder name suggests percentile-normalised,
# 512px, skull-stripped -- TODO confirm against DataDescription.txt).
Data_Path = r'/home/amourav/Brain Data/Processed for UNET/2D/data_percentile_norm_top_99.9_bottom_0.1_shape_512_skull str'
# get_Data chdirs into Data_Path and loads the training arrays plus the description text.
data,target,patient_id,img_rows,img_cols,class_freq,data_info = get_Data(Data_Path)
# [background, foreground] weights, consumed by get_W_Bin_C_Entropy below.
class_W = get_class_weights(target)
print class_W
#sample_plot(data,target,n=10,Figs=5)
# In[5]: | |
#Data Augmentation
# Disabled by default. When enabled, the arrays are replaced by AUGMENT_DATA's output
# (defined outside this view -- presumably the augmented data, confirm).
AUGMENTATION = False
Aug_batches = 1
if AUGMENTATION:
    data,target,patient_id = AUGMENT_DATA('Augmented_conservative','conservative',data,target,patient_id,N_Batches=Aug_batches)
# In[6]: | |
#%% Loss/Metrics Functions
# Tversky index: alpha=.3 weights false positives, beta=.7 false negatives.
# Tversky_loss is not used below.
Tversky, Tversky_loss = get_Tversky(alpha = .3,beta= .7,verb=0)
# Class-weighted BCE; used only as a reported metric here.
W_Bin_C_Entropy = get_W_Bin_C_Entropy(class_W)
Metrics = [dice_coef_loss ,Tversky,W_Bin_C_Entropy,Precision,FPR,FNR,TPR,TNR]
# Training objective: negated soft Dice.
LOSS_Func = dice_coef_loss
# Results go to a folder named after the input data directory.
Data_Path_out = '/home/amourav/Python/Mets/unet/results/data in-{0}'.format(os.path.split(Data_Path)[-1] )
print Data_Path_out
if not os.path.exists(Data_Path_out):
    os.makedirs(Data_Path_out)
os.chdir(Data_Path_out)
#PARAMETERS
#Optimizer = keras.optimizers.Adadelta(lr=1.0)
lr= 1e-5
Optimizer=Adam(lr )
Nsplits=5
Batch_size = 10
Nepochs = 100
Nepochs_Pre = 0
# Settings for early_stop below.
Patience = 15
Stop_delta = 0.02
# Per-split preprocessing inside train_UNET: z-score images with training stats, binarise masks.
Prep_func=normalize
skip = None #skip split
skipTo = None
Reg_coef=0
DropPerc= 0
Batch_Norm = False
#CALLBACKS
print 'lr {0}'.format(lr)
# Two early stoppers and three LR reducers all monitor val_loss with different tolerances.
# NOTE(review): `epsilon` is the older Keras name for ReduceLROnPlateau's `min_delta`;
# confirm it is still accepted by the installed Keras version.
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=Stop_delta, patience=Patience, verbose=1, mode='auto')
early_stop2 = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.005, patience=10, verbose=1, mode='auto')
reduce_lr=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4, verbose=1, mode='auto', epsilon=0.001, cooldown=0, min_lr=0)
reduce_lr2=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.25, patience=6, verbose=1, mode='auto', epsilon=0.01, cooldown=0, min_lr=0)
reduce_lr3=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=8, verbose=1, mode='auto', epsilon=0.1, cooldown=0, min_lr=0)
reduce_LR=[reduce_lr,reduce_lr2,reduce_lr3]
CALLBACK=[early_stop,early_stop2]+ reduce_LR
# NOTE(review): True makes train_UNET run its smoke test (first 100 slices, 1 split,
# 2 epochs); set to False for a real training run.
testMODE=True
# In[7]: | |
# Cross-validated training. Returns one [scores, model, history] entry per trained split;
# S=0 disables save_model.
Model_List = train_UNET(Prep_func,data,target,patient_id, Batch_Norm,Data_Path_out,Optimizer,CALLBACK,Nepochs,Batch_size, Nsplits,lossF=LOSS_Func,METRICS = Metrics,Verbose=0,DropP=DropPerc,S=0,ClassW=None,shw_plt=0, splitN=skip,SkipTo=skipTo, Nepochs_Pre=0,Sample_Weights=None, Reg=Reg_coef,TESTMODE=testMODE)
# In[ ]: | |
# Write a plain-text summary of this run's configuration into the results folder.
os.chdir(Data_Path_out)
filename = 'UNET_INFO.txt'
# Use a dedicated handle name: the original rebound `target`, clobbering the mask array.
with open(filename, 'w') as info_file:
    info_file.write('UNET_INFO')
    info_file.write(' \r\n')
    info_file.write('N_folds {0}'.format(Nsplits))
    info_file.write(' \r\n')
    info_file.write(time.strftime("_%d_%m_%Y_"))
    info_file.write(' \r\n')
    info_file.write('Batch_size {0}'.format(Batch_size))
    info_file.write(' \r\n')
    info_file.write(' Nepochs {0}'.format(Nepochs))
    info_file.write(' \r\n')
    info_file.write('Pretrain Nepochs {0}'.format(Nepochs_Pre))
    info_file.write(' \r\n')
    info_file.write(' \r\n')
    info_file.write(' \r\n')
    # train_UNET keeps its preprocessed arrays local, so summarise the full
    # (unpreprocessed) dataset loaded by get_Data instead.
    info_file.write('{0}: img / target'.format('all'))
    info_file.write(' \r\n')
    info_file.write('shape: {0}/{1}'.format(data.shape, target.shape))
    info_file.write(' \r\n')
    info_file.write('max-min {0}-{1} , {2}-{3}'.format(data.max(), data.min(), target.max(), target.min()))
    info_file.write(' \r\n')
    info_file.write('dtype {0}/{1}'.format(data.dtype, target.dtype))
    info_file.write(' \r\n')
    info_file.write(' \r\n')
    info_file.write('Augmentation {0}, Nbatches {1}'.format(AUGMENTATION, Aug_batches))
    info_file.write(' \r\n')
    info_file.write('loss {0}'.format(LOSS_Func))
    info_file.write(' \r\n')
    info_file.write('Data_Info')
    info_file.write(str(data_info))
    info_file.write(' \r\n')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://github.com/andreimouraviev/Mets/blob/251322d2ae9561b32ac11af829ab6535e1b53e87/UTILS.py | |
import numpy as np | |
import os | |
import matplotlib.pyplot as plt | |
import time | |
import pickle | |
from sklearn.model_selection import GroupShuffleSplit | |
import sklearn | |
import skimage | |
import skimage.transform as tr | |
import scipy.ndimage | |
from scipy.ndimage.interpolation import map_coordinates | |
from scipy.ndimage.filters import gaussian_filter | |
from keras import backend as K | |
import tensorflow as tf | |
import keras | |
from keras.layers import Input, Dense, Convolution2D, merge, Conv2D,Conv2DTranspose | |
from keras.layers import MaxPooling2D, UpSampling2D,Dropout, Flatten,concatenate | |
from keras.models import Model | |
from keras.utils import np_utils | |
from keras import backend as K | |
from keras.models import model_from_json | |
from keras.optimizers import Adam | |
from keras.callbacks import ModelCheckpoint, LearningRateScheduler | |
from keras.regularizers import l1_l2 | |
from keras.layers.normalization import BatchNormalization | |
import keras | |
import sklearn.metrics | |
from keras.models import load_model | |
#%% SEED
# Seed NumPy and TensorFlow (TF1 API) at import time so shuffles and weight
# initialisation repeat across runs, as far as the backend allows.
np.random.seed(7)
tf.set_random_seed(7)
print '...Seed Set'
#%% LOSS | |
# Additive smoothing for the Dice/Tversky ratios; with two empty masks both evaluate to 1.
smooth = 1.
# Small epsilon that keeps the rate metrics (FPR/FNR/TPR/TNR/Precision) from dividing by zero.
smooth2 = 1e-5
#def confusion(TRUE_LIST,PRED_LIST): | |
# for i, (y_true,y_pred) in enumerate(zip(TRUE_LIST,PRED_LIST )): | |
def FP(y_true, y_pred):
    """Soft false-positive count: predicted mass that falls outside the ground truth.

    Uses K.sum, like FPR/FNR/TPR/TNR, so the result stays a backend tensor when
    Keras evaluates it on symbolic tensors (np.sum cannot reduce those).
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    return K.sum(y_pred_f * (1 - y_true_f))
def FN(y_true, y_pred):
    """Soft false-negative count: ground-truth mass the prediction misses.

    Uses K.sum (not np.sum) so it works on the backend tensors Keras passes to metrics.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    return K.sum((1 - y_pred_f) * y_true_f)
def TP(y_true, y_pred):
    """Soft true-positive count: overlap of prediction and ground truth.

    Uses K.sum (not np.sum) so it works on the backend tensors Keras passes to metrics.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    return K.sum(y_pred_f * y_true_f)
def TN(y_true, y_pred):
    """Soft true-negative count: background predicted as background.

    Uses K.sum (not np.sum) so it works on the backend tensors Keras passes to metrics.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    return K.sum((1 - y_pred_f) * (1 - y_true_f))
def FPR(y_true, y_pred):
    """Soft false-positive rate (fall-out): FP / number of negative pixels.

    smooth2 in the denominator avoids division by zero when every pixel is positive.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    negatives = 1 - truth
    false_pos = K.sum(pred * negatives)
    return 1.0 * false_pos / (K.sum(negatives) + smooth2)
def FNR(y_true, y_pred):
    """Soft false-negative rate (miss rate): FN / number of positive pixels."""
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    missed = K.sum((1 - pred) * truth)
    positives = K.sum(truth) + smooth2
    return 1.0 * missed / positives
def TPR(y_true, y_pred):
    """Soft true-positive rate (sensitivity / recall): TP / number of positive pixels."""
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    hits = K.sum(pred * truth)
    positives = K.sum(truth) + smooth2
    return 1.0 * hits / positives
def TNR(y_true, y_pred):
    """Soft true-negative rate (specificity): TN / number of negative pixels."""
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    background = 1 - truth
    correct_background = K.sum((1 - pred) * background)
    return 1.0 * correct_background / (K.sum(background) + smooth2)
def Precision(y_true, y_pred):
    """Soft precision TP / (TP + FP), with smooth2 added to numerator and denominator."""
    truth, pred = K.flatten(y_true), K.flatten(y_pred)
    hits = K.sum(pred * truth)
    false_alarms = K.sum(pred * (1.0 - truth))
    return (1.0 * hits + smooth2) / (hits + false_alarms + smooth2)
def get_Tversky(alpha=.3, beta=.7, verb=0):
    """Build a soft Tversky index metric and its negated loss.

    Tversky = (TP + smooth) / (TP + smooth + alpha*FP + beta*FN), so alpha
    penalises false positives and beta false negatives. ``verb`` is unused.

    Returns:
        ``(Tversky, Tversky_loss)`` closures. Their names matter because saved
        models refer to them by name.
    """
    def Tversky(y_true, y_pred):
        truth = K.flatten(y_true)
        pred = K.flatten(y_pred)
        tp = K.sum(truth * pred)
        fp_term = alpha * K.sum((1 - truth) * pred)  # predicted, not in ground truth
        fn_term = beta * K.sum(truth * (1 - pred))   # in ground truth, not predicted
        return (tp + smooth) / (tp + smooth + fp_term + fn_term)

    def Tversky_loss(y_true, y_pred):
        return -Tversky(y_true, y_pred)

    return Tversky, Tversky_loss
def dice_coef(y_true, y_pred):
    """Soft Dice coefficient 2|A.B| / (|A| + |B|), smoothed by ``smooth``."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2. * overlap + smooth) / (total + smooth)
def dice_coef_loss(y_true, y_pred):
    """Loss to minimise: the negated soft Dice coefficient."""
    dice = dice_coef(y_true, y_pred)
    return -dice
def get_W_Bin_C_Entropy(class_W):
    """Build a class-weighted binary cross-entropy, summed over all pixels.

    class_W[1] weights the foreground (y_true == 1) terms and class_W[0] the background terms.
    """
    eps = 1e-10  # keeps log() finite when a prediction is exactly 0 or 1

    def W_Bin_C_Entropy(y_true, y_pred):
        truth = K.flatten(y_true)
        pred = K.flatten(y_pred)
        fg_term = class_W[1] * K.sum(truth * K.log(pred + eps))
        bg_term = class_W[0] * K.sum((1 - truth) * K.log(1 - pred + eps))
        return -(fg_term + bg_term)

    return W_Bin_C_Entropy
def get_class_weights(target):
    """Return float32 ``[background, foreground]`` weights for a binary mask array.

    The mask is scaled by its maximum. Each class gets the fourth root of the
    other class's pixel count (rarer class, larger weight), rescaled to sum to 2.
    """
    unit = target / target.max()
    counts = {label: np.sum(unit == label) for label in (0, 1)}
    inverse = np.array([counts[1], counts[0]]) ** 0.25
    inverse = inverse / np.sum(inverse) * 2.
    return inverse.astype(np.float32)
#%% | |
def model_eval(y_true, IMGS, model, S=None):
    """Evaluate `model` slice by slice and print the mean and std of each metric.

    Args:
        y_true: iterable of masks, each shaped like one model sample (e.g. (H, W, 1)).
        IMGS: iterable of images paired with `y_true`.
        model: compiled Keras model. ``evaluate`` is called once per slice.
        S: unused; kept for backward compatibility.

    Returns:
        ``(mean_scores, std_scores)`` arrays ordered like ``model.metrics_names``.
        The function used to return None.
    """
    scores = []
    for msk_slice, IMG in zip(y_true, IMGS):
        # Add a leading batch axis. float64 matches the zero-filled buffers used previously.
        temp_msk = np.asarray(msk_slice, dtype=np.float64)[np.newaxis]
        temp_img = np.asarray(IMG, dtype=np.float64)[np.newaxis]
        scores.append(np.array(model.evaluate(temp_img, temp_msk, batch_size=1, verbose=0)))
    scores = np.array(scores)
    mean_scores = np.mean(scores, axis=0)
    std_scores = np.std(scores, axis=0)
    for mean_score, std_score, metric in zip(mean_scores, std_scores, model.metrics_names):
        print('{0} mean: {1} std {2}'.format(metric, mean_score, std_score))
    return mean_scores, std_scores
''' | |
def random_guess_metrics(y_true,S=None): | |
scores=[] | |
for (msk_slice,IMG) in zip(y_true,IMGS): | |
temp=[] | |
random_guess=np.random.rand( msk_slice.shape[0], msk_slice.shape[1], msk_slice.shape[2] ) | |
scores.append(temp ) | |
mean_scores = np.mean(np.array( scores),axis=0) | |
std_scores = np.std(np.array( scores),axis=0) | |
for mean_score,std_score,metric in zip(mean_scores,std_scores,model.metrics_names): | |
print '{0} mean: {1} std {2}'.format(metric,mean_score,std_score) | |
''' | |
#%% data utils | |
def Load_Model(fname, extra_metrics=None):
    """Load a saved Keras model, registering this module's custom losses and metrics.

    Args:
        fname: path to the saved model file.
        extra_metrics: optional mapping of additional custom objects (e.g. the
            Tversky / W_Bin_C_Entropy closures) needed to deserialise the model.
    """
    custom_objects = {
        'dice_coef_loss': dice_coef_loss,
        'FNR': FNR, 'FPR': FPR, 'TPR': TPR, 'TNR': TNR, 'Precision': Precision,
        'FP': FP, 'FN': FN, 'TP': TP, 'TN': TN,
    }
    # None instead of a shared mutable {} default.
    if extra_metrics:
        custom_objects.update(extra_metrics)
    return load_model(fname, custom_objects=custom_objects)
def print_info(imgs, target, lbl):
    """Print shape, value range and dtype of an image array and its target array.

    Args:
        imgs, target: NumPy arrays to summarise.
        lbl: label printed on the header line.
    """
    # Format inside print(); `print(...).format(...)` fails on Python 3.
    print('{0}: img / target'.format(lbl))
    print('shape: {0}/{1}'.format(imgs.shape, target.shape))
    print('max-min {0}-{1} , {2}-{3}'.format(imgs.max(), imgs.min(), target.max(), target.min()))
    print('dtype {0}/{1}'.format(imgs.dtype, target.dtype))
    # One string, so Python 2 does not print a tuple repr.
    print('--------------- \n ')
def recomb_data(data_train, target_train, patient_train, data_valid, target_valid, patient_valid, R=1, Test_size=.08, SPLTS=1):
    """Re-join a train/validation split into single data, target and patient arrays.

    Training rows come first. R, Test_size and SPLTS are accepted but ignored.
    """
    data, target = (np.concatenate(pair, axis=0)
                    for pair in ((data_train, data_valid), (target_train, target_valid)))
    patient_ID = np.concatenate((patient_train, patient_valid))
    return data, target, patient_ID
def get_splits_idxs(data, pid, Nsplits=5, R=1, Test_size=0.08):
    """Return `Nsplits` patient-grouped ``(train_idx, test_idx)`` pairs.

    GroupShuffleSplit keeps all slices with the same `pid` on one side of each
    split; `R` seeds the splitter so the folds are reproducible.
    """
    splitter = GroupShuffleSplit(n_splits=Nsplits, test_size=Test_size, random_state=R)
    folds = splitter.split(list(data), groups=pid)
    return [tuple(fold) for fold in folds]
def split_data(data, target, patient_ID, IDXS):
    """Index the data, target and patient arrays with a ``(train_idx, test_idx)`` pair.

    Returns train data, train target, train patients, then the same three for validation.
    """
    train_idx, test_idx = IDXS
    parts = []
    for idx in (train_idx, test_idx):
        parts.extend((data[idx], target[idx], patient_ID[idx]))
    return tuple(parts)
#%% data preprocessing | |
def prep(train_imgs, train_mask, val_imgs, val_mask):
    """Scale images and masks by 1/255, z-score images with training stats, binarise masks.

    Returns:
        train_imgs (float32), train_mask (bool), val_imgs (float32), val_mask (bool).
    """
    train_imgs = train_imgs.astype(np.float32) / 255
    val_imgs = val_imgs.astype(np.float32) / 255
    train_mask = train_mask.astype(np.float32) / 255
    val_mask = val_mask.astype(np.float32) / 255
    # Validation images use the training mean/std.
    data_mean = np.mean(train_imgs)
    data_std = np.std(train_imgs)
    train_imgs = (train_imgs - data_mean) / data_std
    val_imgs = (val_imgs - data_mean) / data_std
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    return train_imgs, train_mask.astype(bool), val_imgs, val_mask.astype(bool)
def normalize(train_imgs, train_mask, val_imgs, val_mask):
    """Z-score images with training-set statistics and binarise masks (non-zero -> True).

    Returns:
        train_imgs (float32), train_mask (bool), val_imgs (float32), val_mask (bool).
    """
    train_imgs = train_imgs.astype(np.float32)
    val_imgs = val_imgs.astype(np.float32)
    data_mean = np.mean(train_imgs)
    data_std = np.std(train_imgs)
    train_imgs = (train_imgs - data_mean) / data_std
    val_imgs = (val_imgs - data_mean) / data_std
    # `!= 0` matches the old mask/mask.max() -> bool cast for any non-empty mask.
    # The old form turned an all-zero mask into NaN, and NaN casts to True,
    # so an empty mask came out all-foreground.
    train_mask = np.asarray(train_mask).astype(np.float32) != 0
    val_mask = np.asarray(val_mask).astype(np.float32) != 0
    return train_imgs, train_mask, val_imgs, val_mask
def prep_minimal(train_imgs, train_mask, val_imgs, val_mask):
    """Binarise masks via ``(mask / 255).astype(bool)``; images are returned unchanged."""
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    train_mask = (train_mask / 255).astype(bool)
    val_mask = (val_mask / 255).astype(bool)
    return train_imgs, train_mask, val_imgs, val_mask
def prep_int8(train_imgs, train_mask, val_imgs, val_mask):
    """Halve images, centre them on the training mean, round, and store as int8; binarise masks.

    Returns:
        train_imgs (int8), train_mask (bool), val_imgs (int8), val_mask (bool).
    """
    train_imgs, val_imgs = train_imgs / 2., val_imgs / 2.
    mean = train_imgs.mean()
    bounds = np.iinfo(np.int8)

    def _to_int8(imgs):
        # Clip first: astype(np.int8) silently wraps out-of-range values (128 -> -128).
        return np.clip((imgs - mean).round(), bounds.min, bounds.max).astype(np.int8)

    train_mask = (train_mask / 255).astype(bool)
    val_mask = (val_mask / 255).astype(bool)
    return _to_int8(train_imgs), train_mask, _to_int8(val_imgs), val_mask
def prep_int8_aug(train_imgs, train_mask):
    """Training-only variant of prep_int8: halve, centre, round to int8; binarise the mask.

    Returns:
        train_imgs (int8), train_mask (bool).
    """
    halved = train_imgs / 2.
    bounds = np.iinfo(np.int8)
    # Clip first: astype(np.int8) silently wraps out-of-range values.
    centred = np.clip((halved - halved.mean()).round(), bounds.min, bounds.max)
    return centred.astype(np.int8), (train_mask / 255).astype(bool)
def shuffle(data, target):
    """Shuffle `data` and `target` along axis 0 with one shared random permutation.

    Returns:
        ``(shuffled_imgs, shuffled_masks)``. Row k of each comes from the same original row.
    """
    # np.random.permutation instead of shuffling range(): range objects are
    # immutable on Python 3, so np.random.shuffle raised TypeError there.
    data_order = np.random.permutation(data.shape[0])
    shuffled_imgs = np.take(data, data_order, axis=0)
    shuffled_masks = np.take(target, data_order, axis=0)
    return shuffled_imgs, shuffled_masks
def pkl(obj, fname):
    """Pickle `obj` to ``fname + '.pkl'``.

    The file is closed even if pickling fails. Load the result only from trusted
    locations: unpickling can execute arbitrary code.
    """
    with open(fname + '.pkl', 'wb') as output:
        pickle.dump(obj, output)
#%% plot utils | |
def plot_TLoss(hist):
    """Plot the training-loss curve from a Keras history dict (does not call plt.show())."""
    plt.plot(hist['loss'])
    for set_label, text in ((plt.xlabel, 'epochs'), (plt.ylabel, 'loss (-dice coef)'), (plt.title, 'training')):
        set_label(text)
def plt_pts(lst, xlabel='epochs', ylabel='loss (-dice coef)', title='pre-training'):
    """Plot a sequence of values and label the axes and title (does not call plt.show())."""
    plt.plot(lst)
    labels = ((plt.xlabel, xlabel), (plt.ylabel, ylabel), (plt.title, title))
    for set_label, text in labels:
        set_label(text)
def plot_loss(hist):
    """Show training and validation loss side by side, sharing the y axis."""
    # Read both curves before creating the figure.
    val_loss = hist['val_loss']
    loss = hist['loss']
    _, axes = plt.subplots(1, 2, sharey=True)
    for ax, curve, title in zip(axes, (loss, val_loss), ('training', 'validation')):
        ax.plot(curve)
        ax.set_xlabel('epochs')
        ax.set_ylabel('loss (-dice coef)')
        ax.set_title(title)
    plt.show()
    return
def overlay(im, seg, fs=(40, 40)):
    """Show `im` alone, and `im` with the positive pixels of `seg` overlaid.

    Args:
        im: 2D image.
        seg: 2D segmentation. Pixels > 0 are drawn with the 'jet' colormap at alpha 0.5.
        fs: figure size.
    """
    mask = seg > 0
    masked = np.ma.masked_where(mask == 0, mask)
    # Pass the `fs` parameter; the body used to reference an undefined name `f`.
    plt.figure(figsize=fs)
    plt.subplot(1, 2, 1)
    plt.imshow(im, 'gray', interpolation='none')
    plt.subplot(1, 2, 2)
    plt.imshow(im, 'gray', interpolation='none')
    plt.imshow(masked, 'jet', interpolation='none', alpha=0.5)
    plt.show()
def plot_some_data(shuffled_imgs, shuffled_masks, n=3, Figsize=(20, 10), Sride=16):
    """Plot `n` image/mask pairs at indices 2, 2+Sride, 2+2*Sride, ...

    Each pair opens its own figure. Panels take consecutive positions of a
    2 x n subplot grid (image, then mask), and the numbering continues across figures.
    """
    print('plot train data')
    offset = 1
    panel = 1
    for start in range(offset, offset + n * Sride, Sride):
        idx = start + offset
        plt.figure(figsize=Figsize)
        for stack in (shuffled_imgs, shuffled_masks):
            ax = plt.subplot(2, n, panel)
            panel += 1
            plt.imshow(np.squeeze(stack[idx, :, :, :]), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()
def plot_results(shuffled_imgs_val, shuffled_masks_val, img_msk_prediction, n=7, Figsize=(50, 25)):
    """Plot `n` validation slices (indices 4, 9, 14, ...) next to their predictions.

    Each slice opens its own figure. The image and the prediction (with the
    ground-truth mask as a red contour) take consecutive positions of a 2 x n grid.
    """
    stride = 5
    offset = 2
    pcounter = 1
    for i in range(offset, offset + n * stride, stride):
        idx = i + offset
        plt.figure(figsize=Figsize)
        # Input image.
        ax = plt.subplot(2, n, pcounter)
        pcounter += 1
        plt.imshow(np.squeeze(shuffled_imgs_val[idx, :, :, :]), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Prediction with the ground-truth outline.
        ax = plt.subplot(2, n, pcounter)
        pcounter += 1
        plt.imshow(np.squeeze(img_msk_prediction[idx, :, :, :]), cmap='gray')
        # `colors` is matplotlib's contour keyword; `cols` is not a contour option.
        plt.contour(shuffled_masks_val[idx, :, :, 0], colors='r')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
#%% | |
def plot_some_Results(data, ground_truth, target, n=5, Figsize=(20, 40), Sride=5, P=0):
    """Plot `n` strided slices (from index 2, step Sride) of image, ground truth and prediction.

    With `P` truthy each slice goes to Overlay_Cont, otherwise to show_seg.
    """
    print('plot train data')
    offset = 1
    plotter = Overlay_Cont if P else show_seg
    for start in range(offset, offset + n * Sride, Sride):
        k = start + offset
        im = np.squeeze(data[k, :, :, :])
        grd = np.squeeze(ground_truth[k, :, :, :])
        seg = np.squeeze(target[k, :, :, :])
        plotter(im, grd, seg, figs=Figsize)
def sample_plot_results(data, ground_truth, target, n=5, Figsize=(20, 40), P=0):
    """Plot `n` randomly chosen slices (with replacement) of image, ground truth and prediction.

    With `P` truthy each slice is drawn with Overlay_Cont, otherwise with show_seg.
    """
    N = data.shape[0]
    samples = np.random.randint(0, high=N, size=n)
    for s in samples:
        im = np.squeeze(data[s, :, :, :])
        grd = np.squeeze(ground_truth[s, :, :, :])
        seg = np.squeeze(target[s, :, :, :])
        # Exactly one plot per slice. An extra unconditional show_seg call used
        # to draw every slice twice (or both views when P was set).
        if P:
            Overlay_Cont(im, grd, seg, figs=Figsize)
        else:
            show_seg(im, grd, seg, figs=Figsize)
def sample_results_(data, ground_truth, target, n=5, Figsize=(20, 40)):
    """Show `n` random slices (with replacement) using compare_predWcontour."""
    for s in np.random.randint(0, high=data.shape[0], size=n):
        im, grd, seg = (np.squeeze(arr[s, :, :, :]) for arr in (data, ground_truth, target))
        compare_predWcontour(im, grd, seg, figs=Figsize)
def sample_plot(data, target, n=10, Figs=10):
    """Overlay channel 0 of `n` random slices (with replacement) of `target` on `data`."""
    picks = np.random.randint(0, high=data.shape[0], size=n)
    for s in picks:
        print('slice {0}'.format(s))
        Overlay(data[s, :, :, 0], target[s, :, :, 0], figs=(Figs, Figs))
def Overlay(im, seg, figs):
    """Two panels: `im` in viridis, and `im` in gray with the seg > 0 region tinted 'autumn'."""
    positive = seg > 0
    tint = np.ma.masked_where(positive == 0, positive)
    plt.figure(figsize=figs)
    for position in (1, 2):
        ax = plt.subplot(1, 2, position)
        if position == 1:
            plt.imshow(im, 'viridis', interpolation='none')
        else:
            plt.imshow(im, 'gray', interpolation='none')
            plt.imshow(tint, 'autumn', interpolation='none', alpha=0.2)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
def Overlay_Cont(im, grd, seg, figs, thresh=0.2):
    """Side by side: image with the ground-truth contour, and image with the thresholded prediction.

    Args:
        im: 2D image.
        grd: 2D ground-truth mask, drawn as a red contour.
        seg: 2D prediction. Pixels > `thresh` are tinted 'autumn' at alpha 0.2.
        figs: matplotlib figure size.
    """
    mask = seg > thresh
    masked = np.ma.masked_where(mask == 0, mask)
    plt.figure(figsize=figs)
    ax = plt.subplot(1, 2, 1)
    plt.imshow(im, 'gray', interpolation='none')
    # `colors` is the contour keyword; `cols` is not one.
    plt.contour(grd, colors='r')
    # Title rather than xlabel: the hidden x-axis would also hide its label.
    ax.set_title('ground truth')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(1, 2, 2)
    plt.imshow(im, 'gray', interpolation='none')
    plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
    ax.set_title('pred')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()
def show_seg(im, grd, seg, figs, thresh=0.2):
    """Show a prediction in two figures.

    Figure 1: the image with the ground-truth contour, next to the raw
    prediction map in 'jet'. Figure 2: the image with pixels > `thresh` tinted.

    Args:
        im, grd, seg: 2D image, ground-truth mask and prediction.
        figs: figure size; only ``figs[0]`` is used to size both figures.
    """
    mask = seg > thresh
    masked = np.ma.masked_where(mask == 0, mask)
    plt.figure(figsize=(figs[0], 3 * figs[0]))
    ax = plt.subplot(1, 2, 1)
    plt.imshow(im, 'gray', interpolation='none')
    # `colors` is the contour keyword; `cols` is not one.
    plt.contour(grd, colors='r')
    # Titles, because the x-axes are hidden (which would also hide an xlabel).
    ax.set_title('ground truth')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(1, 2, 2)
    plt.imshow(seg, 'jet', interpolation='none')
    ax.set_title('pred')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()
    plt.figure(figsize=(figs[0] / 3, figs[0] / 3))
    plt.imshow(im, 'gray', interpolation='none')
    plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
    plt.xlabel('pred')
    plt.show()
def compare_predWcontour(im, grd, seg, figs, thresh=0.2):
    """Compare a prediction with the ground truth.

    Left panel: the image, with prediction pixels > `thresh` tinted and the
    ground truth drawn as a red contour. Right panel: the raw prediction map.

    Args:
        im, grd, seg: 2D image, ground-truth mask and prediction.
        figs: figure size; only ``figs[0]`` is used.
    """
    mask = seg > thresh
    masked = np.ma.masked_where(mask == 0, mask)
    plt.figure(figsize=(figs[0], 2 * figs[0]))
    ax = plt.subplot(1, 2, 1)
    plt.imshow(im, 'gray', interpolation='none')
    plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
    # `colors` is the contour keyword; `cols` is not one.
    plt.contour(grd, colors='r')
    # Titles, because the hidden x-axes would also hide an xlabel.
    ax.set_title('ground truth (contour)')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(1, 2, 2)
    plt.imshow(seg, 'viridis', interpolation='none')
    ax.set_title('pred')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    plt.show()
def plot2(img1, img2, figs=(30, 60)):
    """Show two grayscale images side by side, with axes hidden."""
    plt.figure(figsize=figs)
    for position, img in enumerate((img1, img2), start=1):
        ax = plt.subplot(1, 2, position)
        plt.imshow(img, 'gray', interpolation='none')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
def plotData2(data1, data2, n=5, Figsize=(30, 60), Sride=5):
    """Show `n` pairs ``data1[k]`` / ``data2[k]`` side by side, for k = 0, Sride, 2*Sride, ..."""
    print('plot train data')
    # Index the function's own arguments; the body used to read undefined
    # `data` / `target` names.
    for k in range(0, n * Sride, Sride):
        img1 = np.squeeze(data1[k, :, :, :])
        img2 = np.squeeze(data2[k, :, :, :])
        plot2(img1, img2, figs=Figsize)
def plot_some_Data(data, target, n=5, Figsize=(20, 40), Sride=5):
    """Overlay up to ``n`` target slices on their data slices.

    Slices 1, 1+Sride, 1+2*Sride, ... are taken along axis 0, squeezed, and
    passed to ``Overlay``.

    Parameters
    ----------
    data, target : 4-D arrays
        Indexed as ``array[i, :, :, :]``.
    n : int
        Maximum number of slices to plot.
    Figsize : tuple
        Figure size forwarded to ``Overlay``.
    Sride : int
        Step between plotted slice indices.
    """
    print('plot train data')
    offset = 1
    # Apply the start offset once. The old code added it both to the range
    # start and to the index, shifting every slice by an extra position.
    stop = min(offset + n * Sride, len(data), len(target))
    for i in range(offset, stop, Sride):
        im = np.squeeze(data[i, :, :, :])
        seg = np.squeeze(target[i, :, :, :])
        Overlay(im, seg, figs=Figsize)
#%% | |
def get_Test_Data(Data_Path):
    """Load the test split (images, masks, patient ids) from ``Data_Path``.

    Side effect: changes the process working directory to ``Data_Path``.

    Returns
    -------
    data, target, patient_id : ndarray
        Contents of data_test.npy, target_test.npy and patient_test.npy.
    img_rows, img_cols : int
        Entries 1 and 2 of ``data.shape`` (``data`` must be 4-D).
    class_freq : dict
        ``{0: 1, 1: sum(target / 255.) / (shape[0] * shape[1] * shape[2])}``
        where ``shape = data.shape``.
    """
    os.chdir(Data_Path)
    loaded = [np.load(fname) for fname in
              ('data_test.npy', 'target_test.npy', 'patient_test.npy')]
    data, target, patient_id = loaded
    print_info(data, target, 'all')
    n_slices, img_rows, img_cols, _ = data.shape
    foreground = np.sum(target / 255.) / (n_slices * img_rows * img_cols)
    class_freq = {0: 1, 1: foreground}
    print('class_freq {0}'.format(class_freq))
    return data, target, patient_id, img_rows, img_cols, class_freq
def get_Data(Data_Path):
    """Load the training split and its text description from ``Data_Path``.

    Side effect: changes the process working directory to ``Data_Path``.

    Returns
    -------
    data, target, patient_id : ndarray
        Contents of data_train.npy, target_train.npy and patient_train.npy.
    img_rows, img_cols : int
        Entries 1 and 2 of ``data.shape`` (``data`` must be 4-D).
    class_freq : dict
        ``{0: 1, 1: sum(target / 255.) / (shape[0] * shape[1] * shape[2])}``.
        The /255 presumably maps masks stored as {0, 255} to {0, 1}; confirm.
    data_info : str
        Full text of DataDescription.txt.
    """
    os.chdir(Data_Path)
    # The context manager closes the file even if read() raises.
    with open('DataDescription.txt', 'r') as datainfofile:
        data_info = datainfofile.read()
    data = np.load('data_train.npy')
    target = np.load('target_train.npy')
    patient_id = np.load('patient_train.npy')
    print_info(data, target, 'all')
    _, img_rows, img_cols, _ = data.shape
    n_pixels = data.shape[0] * data.shape[1] * data.shape[2]
    class_freq = {0: 1, 1: np.sum(target / 255.) / n_pixels}
    print('class_freq {0}'.format(class_freq))
    return data, target, patient_id, img_rows, img_cols, class_freq, data_info
def get_test_Data(Data_Path):
    """Load the test split from ``Data_Path``.

    Side effect: changes the process working directory to ``Data_Path``.
    DataDescription.txt is read (so a missing description still raises), but
    its text is not returned.

    Returns
    -------
    data, target, patient_id : ndarray
        Contents of data_test.npy, target_test.npy and patient_test.npy.
    img_rows, img_cols : int
        Entries 1 and 2 of ``data.shape`` (``data`` must be 4-D).
    class_freq : dict
        ``{0: 1, 1: sum(target / 255.) / (shape[0] * shape[1] * shape[2])}``.
    """
    os.chdir(Data_Path)
    # The context manager closes the file even if read() raises.
    with open('DataDescription.txt', 'r') as datainfofile:
        datainfofile.read()
    data = np.load('data_test.npy')
    target = np.load('target_test.npy')
    patient_id = np.load('patient_test.npy')
    print_info(data, target, 'all')
    _, img_rows, img_cols, _ = data.shape
    n_pixels = data.shape[0] * data.shape[1] * data.shape[2]
    class_freq = {0: 1, 1: np.sum(target / 255.) / n_pixels}
    print('class_freq {0}'.format(class_freq))
    return data, target, patient_id, img_rows, img_cols, class_freq
def get_data(Data_Path):
    """Load the train and validation splits from ``Data_Path`` and merge them.

    Side effect: changes the process working directory to ``Data_Path``.
    DataDescription.txt is read (so a missing description still raises), but
    its text is not returned. The two splits are merged by ``recomb_data``.

    Returns
    -------
    data, target, patient_id : ndarray
        Merged arrays returned by ``recomb_data``.
    img_rows, img_cols : int
        Entries 1 and 2 of ``data.shape`` (``data`` must be 4-D).
    class_freq : dict
        ``{0: 1, 1: sum(target / 255.) / (shape[0] * shape[1] * shape[2])}``.
    """
    os.chdir(Data_Path)
    # The context manager closes the file even if read() raises.
    with open('DataDescription.txt', 'r') as datainfofile:
        datainfofile.read()
    data_train = np.load('data_train.npy')
    target_train = np.load('target_train.npy')
    patient_train = np.load('patient_train.npy')
    print('-----------')
    data_valid = np.load('data_val.npy')
    target_valid = np.load('target_val.npy')
    patient_valid = np.load('patient_val.npy')
    print('-----------')
    data, target, patient_id = recomb_data(data_train, target_train, patient_train,
                                           data_valid, target_valid, patient_valid)
    print_info(data, target, 'all')
    _, img_rows, img_cols, _ = data.shape
    n_pixels = data.shape[0] * data.shape[1] * data.shape[2]
    class_freq = {0: 1, 1: np.sum(target / 255.) / n_pixels}
    print('class_freq {0}'.format(class_freq))
    return data, target, patient_id, img_rows, img_cols, class_freq
#%% | |
def Get_BRATS(path_brats=''):
    """Load the BraTS HGG and LGG arrays and stack them along axis 0.

    Parameters
    ----------
    path_brats : str
        Directory containing HGG_data.npy, HGG_target.npy, LGG_data.npy and
        LGG_target.npy. A trailing separator is optional. Falsy values fall
        back to the hard-coded default directory below.

    Returns
    -------
    data, target : ndarray
        ``np.vstack`` of the HGG arrays followed by the LGG arrays.
    """
    if not path_brats:
        path_brats = '/home/amourav/Brain Data/Processed for UNET/Brats 512/'

    def _load(name):
        # os.path.join works whether or not path_brats ends with a separator;
        # plain concatenation broke for paths given without one.
        return np.load(os.path.join(path_brats, name))

    LGG_data = _load('LGG_data.npy')
    LGG_target = _load('LGG_target.npy')
    HGG_data = _load('HGG_data.npy')
    HGG_target = _load('HGG_target.npy')
    return np.vstack((HGG_data, LGG_data)), np.vstack((HGG_target, LGG_target))
#%% | |
def CropArr(np_im, shape):
    """Centre-crop and/or zero-pad an array to a target shape.

    Each axis is handled independently:

    - If the target is smaller, a window of the input centred on its middle
      is kept.
    - If the target is larger, the input is placed in the middle of a zero
      array.
    - If the sizes are equal, the axis is copied unchanged.

    When the size difference is odd, the low-side offset is rounded down
    (``diff // 2``).

    Parameters
    ----------
    np_im : ndarray
        Input array of any rank (2-D images and 3-D volumes alike).
    shape : sequence of int
        Target shape, one entry per axis of ``np_im``.

    Returns
    -------
    np_padded : ndarray
        New array of shape ``shape`` with the same dtype as ``np_im``.
    """
    np_padded = np.zeros(shape, dtype=np_im.dtype)
    dst_slices = []
    src_slices = []
    # Build one slice per axis. The old version only applied axes 0 and 1 in
    # the final copy, so 3-D inputs whose third axis changed size crashed.
    for new_len, old_len in zip(shape, np_im.shape):
        if new_len < old_len:
            # Crop: keep a centred window of the input.
            start = (old_len - new_len) // 2
            src_slices.append(slice(start, start + new_len))
            dst_slices.append(slice(0, new_len))
        else:
            # Pad (or plain copy when equal): centre the input in the output.
            start = (new_len - old_len) // 2
            src_slices.append(slice(0, old_len))
            dst_slices.append(slice(start, start + old_len))
    np_padded[tuple(dst_slices)] = np_im[tuple(src_slices)]
    return np_padded
def augment(img, msk, Type, random_state=None, verbose=False):
    """Randomly augment a 2-D image and its mask with the same geometry.

    Steps, in order:

    1. horizontal flip
    2. vertical flip
    3. integer shift
    4. rotation (shape kept)
    5. zoom, then ``CropArr`` back to the input shape
    6. shear
    7. gamma correction (image only)
    8. elastic deformation

    Parameters
    ----------
    img : 2-D array
        Image to augment. The shear step computes ``255 * tr.warp(...)`` and
        casts to uint8. Presumably ``img`` is uint8 in [0, 255]; TODO confirm
        behaviour for other dtypes.
    msk : 2-D array
        Mask with the same shape as ``img``. It is resampled with order 0
        (nearest) where the transform allows it, and binarised to {0, 255}
        (uint8) after the shear.
    Type : {'conservative', 'liberal'}
        Selects the parameter ranges. 'conservative' never flips vertically
        and disables shear and elastic deformation (alpha = sigma = 0).
    random_state : numpy.random.RandomState, optional
        Used only for the elastic displacement fields. All other parameters
        are drawn from the global ``np.random`` state. A fresh unseeded
        ``RandomState`` is created when omitted.
    verbose : bool
        Print the sampled parameters.

    Returns
    -------
    img_e, msk_e : 2-D arrays
        Augmented image and mask, shaped like ``img``.

    Raises
    ------
    ValueError
        If ``Type`` is not 'conservative' or 'liberal'.
    """
    # RANGE FOR PARAMETERS OF DATA AUGMENTATION
    # The order of draws from the global np.random state is unchanged, so
    # seeded runs reproduce the same augmentations.
    if Type == 'conservative':
        FlipH = np.random.randint(2, size=1)
        FlipV = np.random.randint(1, size=1)  # randint(1) is always 0: no vertical flips
        SHFT_row, SHFT_col = np.random.randint(-60, 10), np.random.randint(-60, 60)
        theta = np.random.randint(-10, 10)
        zoom = .9 + 0.2 * np.random.rand()
        shear_angle = 0
        gam = .8 + 0.4 * np.random.rand()
        alpha, sigma = 0, 0  # elastic deformation disabled
    elif Type == 'liberal':
        FlipH = np.random.randint(2, size=1)
        FlipV = np.random.randint(2, size=1)
        SHFT_row, SHFT_col = np.random.randint(-90, 10), np.random.randint(-80, 80)
        theta = np.random.randint(-20, 20)
        zoom = .9 + 0.2 * np.random.rand()
        shear_angle = -20 + 40 * np.random.rand()
        gam = .8 + 0.4 * np.random.rand()
        alpha, sigma = 720, 24  # elastic deformation params
    else:
        # Fail fast: otherwise every parameter below would be unbound.
        raise ValueError(
            "augment: Type must be 'conservative' or 'liberal', got {0!r}".format(Type))

    if random_state is None:
        random_state = np.random.RandomState(None)
    stdr_size = img.shape

    if verbose:
        print('shift:  {0} {1} \r\ntheta {2} \r\nzoom {3} \r\nshear {4} \r\n'
              'gamma {5} \r\nalh {6}, sig {7}'.format(
                  SHFT_row, SHFT_col, theta, zoom, shear_angle, gam, alpha, sigma))

    ## APPLY TRANSFORMATIONS (image and mask always receive the same geometry)
    # Flips (slicing returns views)
    img_t, msk_t = img, msk
    if FlipH:
        img_t, msk_t = img_t[:, ::-1], msk_t[:, ::-1]
    if FlipV:
        img_t, msk_t = img_t[::-1, :], msk_t[::-1, :]
    # Shift by whole pixels. The mask uses nearest-neighbour (order=0) like
    # every other mask resampling step, so spline ringing cannot leak into it.
    img_shft = scipy.ndimage.shift(img_t, [SHFT_row, SHFT_col])
    msk_shft = scipy.ndimage.shift(msk_t, [SHFT_row, SHFT_col], order=0)
    # Rotate, keeping the input shape
    img_rot = scipy.ndimage.rotate(img_shft, theta, reshape=False)
    msk_rot = scipy.ndimage.rotate(msk_shft, theta, reshape=False, order=0)
    # Zoom, then crop/pad back to the input shape
    img_C = CropArr(scipy.ndimage.zoom(img_rot, (zoom, zoom)), stdr_size)
    msk_C = CropArr(scipy.ndimage.zoom(msk_rot, (zoom, zoom), mode='nearest', order=0),
                    stdr_size)
    # Shear; the extra translation presumably re-centres the sheared content
    shftX = stdr_size[0] * np.sin(shear_angle * np.pi / 180) / 2
    shear_tf = tr.AffineTransform(translation=[shftX, 0], shear=shear_angle * np.pi / 180)
    img_shear = (255 * tr.warp(img_C, shear_tf)).astype(np.uint8)
    msk_shear = (255 * (tr.warp(msk_C, shear_tf, order=3) > .5)).astype(np.uint8)
    # Gamma correction (image only)
    img_G = (255 * ((img_shear / 255.) ** gam)).astype(np.uint8)
    # Elastic deformation: smoothed random displacement fields scaled by alpha
    dx = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma,
                         mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma,
                         mode="constant", cval=0) * alpha
    x, y = np.meshgrid(np.arange(stdr_size[0]), np.arange(stdr_size[1]), indexing='ij')
    indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
    img_e = map_coordinates(img_G, indices, order=1).reshape(stdr_size)
    msk_e = map_coordinates(msk_shear, indices, order=0).reshape(stdr_size)
    return img_e, msk_e
def aug_liberal(img, msk, random_state=None, verbose=False):
    """Apply the 'liberal' augmentation preset to a 2-D image/mask pair.

    This used to be a line-for-line copy of the 'liberal' branch of
    ``augment``: same parameter ranges, same order of ``np.random`` draws,
    same transformation pipeline. It now delegates so the two cannot drift
    apart.

    Parameters
    ----------
    img, msk : 2-D arrays
        Image and mask of the same shape.
    random_state : numpy.random.RandomState, optional
        Forwarded to ``augment`` (used for the elastic deformation).
    verbose : bool
        Print the sampled parameters.

    Returns
    -------
    img_e, msk_e : 2-D arrays
        Augmented image and mask.
    """
    return augment(img, msk, 'liberal', random_state=random_state, verbose=verbose)
def make_AUG(AUG_DATA, Type, data, target, patient_id, N_BATCHES=3):
    """Generate ``N_BATCHES`` augmented copies of (data, target) on disk.

    Creates the directory ``AUG_DATA`` (``os.mkdir``, so it must not exist
    yet) and writes:

    - aug_patientID.npy: ``patient_id``, once;
    - aug_data_{Type}_{j}.npy and aug_target_{Type}_{j}.npy for each batch j.

    Every slice ``[i, :, :, 0]`` is passed through ``augment`` with ``Type``.
    Progress is printed every 300 slices.
    """
    print(N_BATCHES)
    os.mkdir(AUG_DATA)
    print('Generating Augmentations...')
    prefix = AUG_DATA + os.sep
    np.save(prefix + 'aug_patientID.npy', patient_id)
    n_total = data.shape[0]
    for batch in range(N_BATCHES):
        data_aug = np.zeros_like(data)
        target_aug = np.zeros_like(target)
        print('batch {0} / {1}'.format(batch + 1, N_BATCHES))
        for idx, (image, label) in enumerate(zip(data, target)):
            aug_img, aug_msk = augment(image[:, :, 0], label[:, :, 0], Type)
            data_aug[idx, :, :, 0] = aug_img
            target_aug[idx, :, :, 0] = aug_msk
            if idx % 300 == 0:
                print('{0}/{1}'.format(idx, n_total))
        np.save(prefix + 'aug_data_{1}_{0}.npy'.format(batch, Type), data_aug)
        np.save(prefix + 'aug_target_{1}_{0}.npy'.format(batch, Type), target_aug)
def get_AUG_data(AUG_DATA, N_batches, Type):
    """Load augmented batches previously written to ``AUG_DATA``.

    Parameters
    ----------
    AUG_DATA : str
        Directory holding aug_data_{Type}_{j}.npy, aug_target_{Type}_{j}.npy
        and aug_patientID.npy.
    N_batches : int
        Number of batches to load (j = 0 .. N_batches - 1).
    Type : str
        Augmentation preset name embedded in the file names.

    Returns
    -------
    xs, ys : list of ndarray
        Per-batch data and target arrays, in batch order.
    patient_ids : ndarray
        Contents of aug_patientID.npy.
    """
    prefix = AUG_DATA + os.sep
    xs, ys = [], []
    for batch_idx in range(N_batches):
        xs.append(np.load(prefix + 'aug_data_{1}_{0}.npy'.format(batch_idx, Type)))
        ys.append(np.load(prefix + 'aug_target_{1}_{0}.npy'.format(batch_idx, Type)))
    patient_ids = np.load(prefix + 'aug_patientID.npy')
    return xs, ys, patient_ids
def get_AUG(AUG_DATA, data, target, patient_id, Type, N_batches=3):
    """Return augmented batches, generating them first if needed.

    If ``AUG_DATA`` is not an existing directory, ``make_AUG`` creates it and
    writes ``N_batches`` augmented batches. The batches are then loaded with
    ``get_AUG_data``.

    Returns
    -------
    AUGX_list, AUGY_list, patient_id_aug
        As returned by ``get_AUG_data``.
    """
    if not os.path.isdir(AUG_DATA):
        # Generate exactly as many batches as will be loaded below; make_AUG
        # previously always used its default of 3.
        make_AUG(AUG_DATA, Type, data, target, patient_id, N_BATCHES=N_batches)
    # get_AUG_data takes (dir, N_batches, Type). The generate branch used to
    # pass (dir, Type, N_batches), so a fresh directory always failed to load.
    return get_AUG_data(AUG_DATA, N_batches, Type)
def combine_augdata(data, target, patient_id, AUGX_list, AUGY_list, patient_id_aug):
    """Append augmented batches to the original arrays along axis 0.

    Parameters
    ----------
    data, target, patient_id : ndarray
        Original arrays.
    AUGX_list, AUGY_list : list of ndarray
        Augmented data/target batches, paired with ``zip`` (so the shorter
        list bounds the number of batches used).
    patient_id_aug : ndarray
        Patient ids for one augmented batch. Appended once per batch used.

    Returns
    -------
    data_plus, target_plus, patient_id_plus : ndarray
        The originals followed by each augmented batch, in order.
    """
    pairs = list(zip(AUGX_list, AUGY_list))
    if not pairs:
        return data, target, patient_id
    # Stack each output once. Re-stacking inside the loop copied the growing
    # array once per batch (quadratic in the number of batches).
    data_plus = np.vstack([data] + [aug_x for aug_x, _ in pairs])
    target_plus = np.vstack([target] + [aug_y for _, aug_y in pairs])
    patient_id_plus = np.concatenate([patient_id] + [patient_id_aug] * len(pairs))
    return data_plus, target_plus, patient_id_plus
def AUGMENT_DATA(AUG_DATA, TYPE, data, target, patient_id, N_Batches=1):
    """Obtain ``N_Batches`` augmented batches and append them to the originals.

    The batches are loaded from (or first generated into) ``AUG_DATA`` by
    ``get_AUG`` with preset ``TYPE``, then merged by ``combine_augdata``.

    Returns
    -------
    data_plus, target_plus, patient_id_plus : ndarray
    """
    aug_batches = get_AUG(AUG_DATA, data, target, patient_id, TYPE, N_batches=N_Batches)
    return combine_augdata(data, target, patient_id, *aug_batches)
def shw(im, Cmap='gray', Figs=(5, 5)):
    """Show a single image in its own figure.

    Parameters
    ----------
    im : array
        Image passed to ``plt.imshow``.
    Cmap : str
        Colormap name (default 'gray').
    Figs : tuple
        Figure size (default 5 x 5).
    """
    figure_options = {'figsize': Figs}
    plt.figure(**figure_options)
    plt.imshow(im, cmap=Cmap)
    plt.show()
def set_gpu_limit(lim):
    """Install a Keras backend session whose GPU memory use is capped at ``lim``.

    ``lim`` is assigned to ``gpu_options.per_process_gpu_memory_fraction`` of a
    TF1-style ``tf.ConfigProto``. The resulting ``tf.Session`` is registered
    with Keras via ``set_session``.
    """
    from keras.backend.tensorflow_backend import set_session
    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = lim
    limited_session = tf.Session(config=session_config)
    set_session(limited_session)
def write_eval(IMGS, MSKS, model, i, patients):
    """Evaluate ``model`` on (IMGS, MSKS), print each metric and save them to a file.

    The scores are written to ``model_eval{i}.txt`` in the current working
    directory. The first line names ``patients``; each following line holds
    one ``'<metric> score: <value>'`` pair.

    Parameters
    ----------
    IMGS, MSKS : arrays
        Inputs and targets passed to ``model.evaluate`` (batch_size=1).
    model : Keras model
        Must provide ``evaluate`` and ``metrics_names``.
    i : object
        Formatted into the output file name.
    patients : object
        Formatted into the file header.
    """
    print('\n validation scores: \n')
    scores = model.evaluate(IMGS, MSKS, batch_size=1, verbose=0)
    # A model compiled without extra metrics returns a single scalar loss;
    # wrap it so it pairs with metrics_names (['loss']) below.
    if np.ndim(scores) == 0:
        scores = [scores]
    # The context manager closes the file even if a write raises.
    with open('model_eval{0}.txt'.format(i), 'w') as txtfile:
        txtfile.write('MODEL EVAL - Patients {0} \r\n'.format(patients))
        for score, metric in zip(scores, model.metrics_names):
            print('{0} score: {1}'.format(metric, score))
            txtfile.write('{0} score: {1} \r\n'.format(metric, score))
def save_unet(model, i):
    """Persist a trained Keras model's history, weights and architecture.

    Writes into the current working directory, with ``i`` formatted into each
    name:

    - 'trainning_hist_{i}': ``model.history.history``, saved via the file's
      ``pkl`` helper;
    - model_weights_{i}.npy: the list from ``model.get_weights()``, stored as
      a 1-D object array (load with ``np.load(..., allow_pickle=True)``);
    - model_json_{i}.txt: ``model.to_json()``;
    - model_hist_{i}.txt: each history value followed by a blank line.
    """
    i = str(i)
    history = model.history.history
    pkl(history, 'trainning_hist_{0}'.format(i))

    weights = model.get_weights()
    # Layer weights have different shapes. Recent NumPy refuses to build a
    # ragged array implicitly, so fill an explicit object array element-wise.
    weights_arr = np.empty(len(weights), dtype=object)
    for k, w in enumerate(weights):
        weights_arr[k] = w
    np.save('model_weights_{0}.npy'.format(i), weights_arr)

    with open("model_json_{0}.txt".format(i), "w") as text_file:
        text_file.write(model.to_json())

    with open("model_hist_{0}.txt".format(i), "w") as hist_file:
        for key in history:
            hist_file.write('{0}'.format(history[key]))
            hist_file.write('\n')
            hist_file.write('\n')
def avg_loss(model_hist, last_n=10):
    """Summarise the final validation losses of several training runs.

    For each history dict, the mean of its last ``last_n`` ``'val_loss'``
    values is printed. Then the mean, standard deviation and standard error
    (std / sqrt(number of runs)) of those per-run means are printed.

    Parameters
    ----------
    model_hist : sequence of dict
        History dicts, each holding a ``'val_loss'`` sequence.
    last_n : int
        Number of trailing epochs averaged per run (default 10; must be > 0).

    Returns
    -------
    (mean, std, se) : tuple of float, or None if ``model_hist`` is empty.
    """
    if len(model_hist) == 0:
        return None
    val_loss_avg = []
    for hist in model_hist:
        # Take the last `last_n` epochs *including* the final one; the old
        # slice [-10:-1] silently dropped the last epoch.
        val_avg = np.mean(hist['val_loss'][-last_n:])
        print(val_avg)
        val_loss_avg.append(val_avg)
    mean = np.mean(val_loss_avg)
    std = np.std(val_loss_avg)
    # The standard error of the mean is std / sqrt(n). The old std / (n - 1)
    # was wrong and divided by zero for a single run.
    se = std / np.sqrt(len(val_loss_avg))
    print('\r\n final avg {0} - std {1} - se {2}'.format(mean, std, se))
    return mean, std, se
class CUSTOMCALL2(keras.callbacks.Callback):
    """Keras callback that records per-epoch losses and prints a progress line.

    State is reset at the start of each training run:
      losses     -- the 'loss' value from the logs at the end of each epoch
      val_losses -- the 'val_loss' value per epoch (None if not reported)
      counter1   -- number of epochs completed so far
    """

    def on_train_begin(self, logs=None):
        # `logs=None` instead of a shared mutable `{}` default.
        self.losses = []
        self.val_losses = []
        self.counter1 = 0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        # val_losses was initialised but never filled; record it alongside loss.
        self.val_losses.append(logs.get('val_loss'))
        self.counter1 += 1
        print('epoch {0} loss {1}'.format(self.counter1, logs.get('loss')))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment