# https://github.com/andreimouraviev/Mets/blob/a8ce43b335584187f0728a4b3975c679ccf06cbc/UNET_2D_AUG17.py
# coding: utf-8
# In[1]:
import os
os.chdir(r'/home/amourav/Python')
import sklearn
from UTILS import *
K.set_image_dim_ordering('tf')
K.set_image_data_format('channels_last')
os.environ["CUDA_VISIBLE_DEVICES"]="6"
#set_gpu_limit(0.5)
# In[2]:
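# get_Unet builds a standard 2D U-Net for binary segmentation:
#   - contracting path: 4 blocks of two 3x3 ReLU convolutions (32 -> 256 filters),
#     each followed by 2x2 max pooling and dropout, plus a 512-filter bottleneck
#   - expanding path: 2x2 transposed convolutions concatenated with the matching
#     encoder features (skip connections on channel axis=3), dropout, two convolutions
#   - a final 1x1 sigmoid convolution outputs a per-pixel foreground probability map
#   - optional L2 weight regularization (reg) and input batch normalization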
def get_Unet(img_rows, img_cols,LossF ,Metrics,Optimizer=Adam(1e-5),DropP=0,reg=0.000,batch_norm=False):
#init_W = keras.initializers.RandomNormal(mean=0.001, stddev=0.08, seed=7)
#init_B = keras.initializers.RandomNormal(mean=0.001, stddev=0.004, seed=7)
L2 = keras.regularizers.l2(reg)
print 'Opti {0}, DropP {1}, Loss {2}, reg {3}'.format(Optimizer, DropP, LossF, reg)
inputs = Input(( img_rows, img_cols, 1))
# BatchNormalization is a layer: instantiate it, then apply it to the tensor.
# Keep the original Input tensor untouched so Model(inputs=inputs, ...) below still works.
norm_in = BatchNormalization()(inputs) if batch_norm else inputs
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(norm_in)
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
pool1 = Dropout(DropP)(pool1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(pool1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
pool2 = Dropout(DropP)(pool2)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(pool2)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
pool3 = Dropout(DropP)(pool3)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(pool3)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
pool4 = Dropout(DropP)(pool4)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(pool4)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv5)
up6 = concatenate([Conv2DTranspose(256,(2, 2), strides=(2, 2), padding='same',kernel_regularizer=L2)(conv5), conv4],name='up6', axis=3)
up6 = Dropout(DropP)(up6)
conv6 = Conv2D(256,(3, 3), activation='relu', padding='same',kernel_regularizer=L2)(up6)
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv6)
up7 = concatenate([Conv2DTranspose(128,(2, 2), strides=(2, 2), padding='same',kernel_regularizer=L2)(conv6), conv3],name='up7', axis=3)
up7 = Dropout(DropP)(up7)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(up7)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv7)
up8 = concatenate([Conv2DTranspose(64,(2, 2), strides=(2, 2), padding='same',kernel_regularizer=L2)(conv7), conv2],name='up8', axis=3)
up8 = Dropout(DropP)(up8)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(up8)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv8)
up9 = concatenate([Conv2DTranspose(32,(2, 2), strides=(2, 2), padding='same',kernel_regularizer=L2)(conv8), conv1],name='up9',axis=3)
up9 = Dropout(DropP)(up9)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(up9)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same',kernel_regularizer=L2)(conv9)
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
model = Model(inputs=inputs, outputs=conv10)
model.compile(optimizer=Optimizer, loss=LossF, metrics=Metrics)
return model
# In[3]:
# CALLBACKS
class Print_Loss(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
pass
def on_epoch_end(self, epoch, logs={}):
print '-epoch {0}'.format(epoch) ,'loss {0}'.format( logs.get('loss') ), 'val_loss {0}'.format( logs.get('val_loss') )
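# train_UNET runs a patient-grouped cross-validation loop:
#   - get_splits_idxs uses GroupShuffleSplit so every split keeps all slices of a
#     patient on one side of the train/validation divide
#   - for each split the data are preprocessed with Prep_func, shuffled, and a fresh
#     U-Net is built (or reloaded from unet_{i}.hdf5 if a checkpoint already exists)
#   - ModelCheckpoint keeps the best val_loss weights per split and
#     Record_LOSS_METRICS appends per-epoch losses/metrics to text files
#   - TESTMODE trims the run to 100 slices, 1 split and 2 epochs as a smoke test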
def train_UNET(Prep_func,data,target,patient_id,Batch_Norm,Data_Path_out,Optimizer,CALLBACK,Nepochs, Batch_size,Nsplits,lossF,METRICS,Verbose=0,DropP=0,S=1,ClassW=None,shw_plt=0,splitN=None ,SkipTo = None,Nepochs_Pre=0,Sample_Weights=None,Reg=0.000,TESTMODE=False):
if TESTMODE: data,target,patient_id,Nsplits,Nepochs = data[0:100],target[0:100],patient_id[0:100],1,2
os.chdir(Data_Path_out)
MODEL = []
Model_List = []
VAL_PATIENTS = []
INDXS = get_splits_idxs(data,patient_id,Nsplits)
for i,TRAIN_TEST_SPLIT in enumerate( INDXS ):
Train_idx,Test_idx = TRAIN_TEST_SPLIT
print 'shuffle {0} \n'.format(i)
if splitN==i or (SkipTo is not None and i<SkipTo): # skip an explicitly excluded split, or everything before SkipTo
print 'skipping {0} \n'.format(i)
continue
data_train,target_train,patient_train,data_valid,target_valid,patient_valid = split_data(data,target,patient_id,TRAIN_TEST_SPLIT)
Sample_Weights_trn = None
if not(Sample_Weights is None):Sample_Weights_trn = Sample_Weights[Train_idx]
if os.path.isfile('unet_{0}.hdf5'.format(i) ):
model = Load_Model('unet_{0}.hdf5'.format(i),{'Tversky':Tversky,'W_Bin_C_Entropy':W_Bin_C_Entropy})
print 'loading model..'
else:
model = get_Unet(img_rows,img_cols,lossF,METRICS,Optimizer,DropP,reg=Reg,batch_norm=Batch_Norm)
''' #Pretrain
if Nepochs_Pre>0:
preT_data,PreT_target = get_BRATS()
preT_data_prep,PreT_target_prep,_,_=Prep_func(preT_data,PreT_target,data_valid,target_valid)
custom2 = CUSTOMCALL()
model.fit(preT_data_prep, PreT_target_prep, batch_size=Batch_size, epochs=Nepochs_Pre,\
verbose=Verbose, shuffle=True,callbacks=[custom2],class_weight=None,validation_split=0.1)
'''
#Prep data for unet
data_train_prep,target_train_prep,data_valid_prep,target_valid_prep=Prep_func(data_train,target_train,data_valid,target_valid)
print_info(data_train_prep,target_train_prep,'train')
print_info(data_valid_prep,target_valid_prep,'valid')
# Shuffle
shuffled_imgs,shuffled_masks = shuffle(data_train_prep,target_train_prep)
shuffled_imgs_validation,shuffled_masks_validation = shuffle(data_valid_prep,target_valid_prep)
#if shw_plt: plot_some_Data(shuffled_imgs,shuffled_masks,n=4,Figsize=(30,60),Sride=5)
if shw_plt and i==1: sample_plot(shuffled_imgs,shuffled_masks,n=5,Figs=10)
# Callbacks
extra_CBs = [Print_Loss()] if not Verbose else [] # avoid appending to the shared CALLBACK list on every split
model_checkpoint = ModelCheckpoint('unet_{0}.hdf5'.format(i), monitor='val_loss', save_best_only=True)
log_loss = Record_LOSS_METRICS()
log_loss.set_split(i)
CALLBACKs = CALLBACK + extra_CBs + [model_checkpoint,log_loss]
#Train UNET
t1 = time.time()
print('-'*30,'\n' ,'Fitting model...','\n',('-'*30))
history_callback = model.fit(shuffled_imgs, shuffled_masks, batch_size=Batch_size, epochs=Nepochs,\
verbose=Verbose, shuffle=True,callbacks=CALLBACKs,class_weight=ClassW,\
validation_data=(shuffled_imgs_validation,shuffled_masks_validation),\
sample_weight=Sample_Weights_trn)
t2 = time.time()
#Save Model
if S: save_unet(model,i) # save_unet (defined in UTILS) writes weights, json and training history
print('----------------','\n','train time {0} h'.format( (t2-t1)/3600 ),'\n','----------------')
print 'Patients: {0}'.format( np.unique(patient_valid) ),'\r\n'
VAL_PATIENTS.append(np.unique(patient_valid))
if shw_plt: plot_loss(model.history.history)
img_msk_prediction = model.predict(shuffled_imgs_validation, verbose=0 , batch_size=1)
if shw_plt: sample_results_(shuffled_imgs_validation,shuffled_masks_validation,img_msk_prediction,n=15\
,Figsize=(20,40))
scores = model.evaluate(shuffled_imgs, shuffled_masks, batch_size=1, verbose=0)
for score,metric in zip(scores,model.metrics_names):
print '{0} score: {1}'.format(metric,score)
MODEL.append(scores )
#model_eval(shuffled_masks,img_msk_prediction)
#print 'model eval \n {0} \n '.format(MODEL[0])
MODEL.append(model)
MODEL.append(history_callback.history)
Model_List.append(MODEL)
print '\n','----------------','\n','----------------'
del model
del MODEL
return Model_List
class Record_LOSS_METRICS(keras.callbacks.Callback):
def set_split(self,i_split):
self.i_split=i_split
def on_train_begin(self, logs={}):
FID = open('UNET_loss_{0}.txt'.format(self.i_split), 'w')
FID.write('RECORDING LOSS: \r\n')
FID.close()
FID2 = open('UNET_metrics_{0}.txt'.format(self.i_split), 'w')
FID2.write('RECORDING METRICS: \r\n')
FID2.close()
def on_epoch_end(self, epoch, logs={}):
FID = open('UNET_loss_{0}.txt'.format(self.i_split), 'a')
FID.write('{0} train {1} val {2} \r\n'.format(epoch,logs.get('loss'),logs.get('val_loss')) )
FID.close()
FID2 = open('UNET_metrics_{0}.txt'.format(self.i_split), 'a')
FID2.write('{0} FNR {1} FPR {2} TNR {3} TPR {4} Precision {5} Dice {6} \r\n'.\
format(epoch,logs.get('FNR'),logs.get('FPR'),logs.get('TNR'),logs.get('TPR'),logs.get('Precision'),logs.get('dice_coef_loss') ))
FID2.close()
'''
schedule = 10
decay = .5
def scheduler(epoch):
LR = self.lr.get_value()
if epoch == 1: self.losses_ = []
if epoch == schedule:
self.losses_.append(logs.get('loss'))
if self.losses_[-1]>self.losses_[-2] and self.losses_[-1]>self.losses_[-3]:
LR = decay*self.lr.get_value()
return LR
change_lr = LearningRateScheduler(scheduler)
'''
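# get_class_weights computes damped inverse-frequency weights for the weighted BCE:
# pixel counts per class are swapped, raised to the 0.25 power, then normalized to
# sum to 2, so the rarer (lesion) class is up-weighted but not by the full frequency
# ratio. Worked example with hypothetical counts: frequencies [10000 background,
# 100 lesion] -> swapped and **0.25 -> [3.16, 10.0] -> normalized to sum 2 ->
# weights ~[0.48, 1.52], i.e. lesion pixels count roughly 3.2x more than background.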
def get_class_weights(target):
target=target/target.max()
class_frequencies = np.array([ np.sum(target == 0), np.sum(target == 1) ])
class_weights = (class_frequencies[[1,0]])**0.25
class_weights = class_weights / np.sum(class_weights) * 2.
class_weights = class_weights.astype(np.float32)
return class_weights
# In[4]:
Data_Path = r'/home/amourav/Brain Data/Processed for UNET/2D/data_percentile_norm_top_99.9_bottom_0.1_shape_512_skull str'
data,target,patient_id,img_rows,img_cols,class_freq,data_info = get_Data(Data_Path)
class_W = get_class_weights(target)
print class_W
#sample_plot(data,target,n=10,Figs=5)
# In[5]:
#Data Augmentation
AUGMENTATION = False
Aug_batches = 1
if AUGMENTATION:
data,target,patient_id = AUGMENT_DATA('Augmented_conservative','conservative',data,target,patient_id,N_Batches=Aug_batches)
# In[6]:
#%% Loss/Metrics Functions
Tversky, Tversky_loss = get_Tversky(alpha = .3,beta= .7,verb=0)
W_Bin_C_Entropy = get_W_Bin_C_Entropy(class_W)
Metrics = [dice_coef_loss ,Tversky,W_Bin_C_Entropy,Precision,FPR,FNR,TPR,TNR]
LOSS_Func = dice_coef_loss
Data_Path_out = '/home/amourav/Python/Mets/unet/results/data in-{0}'.format(os.path.split(Data_Path)[-1] )
print Data_Path_out
if not os.path.exists(Data_Path_out):
os.makedirs(Data_Path_out)
os.chdir(Data_Path_out)
#PARAMETERS
#Optimizer = keras.optimizers.Adadelta(lr=1.0)
lr= 1e-5
Optimizer=Adam(lr )
Nsplits=5
Batch_size = 10
Nepochs = 100
Nepochs_Pre = 0
Patience = 15
Stop_delta = 0.02
Prep_func=normalize
skip = None #skip split
skipTo = None
Reg_coef=0
DropPerc= 0
Batch_Norm = False
#CALLBACKS
print 'lr {0}'.format(lr)
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=Stop_delta, patience=Patience, verbose=1, mode='auto')
early_stop2 = keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.005, patience=10, verbose=1, mode='auto')
reduce_lr=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4, verbose=1, mode='auto', epsilon=0.001, cooldown=0, min_lr=0)
reduce_lr2=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.25, patience=6, verbose=1, mode='auto', epsilon=0.01, cooldown=0, min_lr=0)
reduce_lr3=keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=8, verbose=1, mode='auto', epsilon=0.1, cooldown=0, min_lr=0)
reduce_LR=[reduce_lr,reduce_lr2,reduce_lr3]
CALLBACK=[early_stop,early_stop2]+ reduce_LR
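# NOTE: testMODE=True below makes train_UNET run only a smoke test (100 slices,
# 1 split, 2 epochs); set it to False for a full training run.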
testMODE=True
# In[7]:
Model_List = train_UNET(Prep_func,data,target,patient_id, Batch_Norm,Data_Path_out,Optimizer,CALLBACK,Nepochs,Batch_size, Nsplits,lossF=LOSS_Func,METRICS = Metrics,Verbose=0,DropP=DropPerc,S=0,ClassW=None,shw_plt=0, splitN=skip,SkipTo=skipTo, Nepochs_Pre=0,Sample_Weights=None, Reg=Reg_coef,TESTMODE=testMODE)
# In[ ]:
os.chdir(Data_Path_out)
filename = 'UNET_INFO.txt'
info_file = open(filename, 'w') # use a separate name so the `target` mask array is not shadowed by a file handle
info_file.write('UNET_INFO')
info_file.write(' \r\n')
info_file.write('N_folds {0}'.format(Nsplits) )
info_file.write(' \r\n')
info_file.write( time.strftime("_%d_%m_%Y_") )
info_file.write(' \r\n')
info_file.write('Batch_size {0}'.format( Batch_size) )
info_file.write(' \r\n')
info_file.write(' Nepochs {0}'.format(Nepochs) )
info_file.write(' \r\n')
info_file.write('Pretrain Nepochs {0}'.format(Nepochs_Pre) )
info_file.write(' \r\n')
info_file.write(' \r\n')
info_file.write(' \r\n')
info_file.write('{0}: img / target'.format('training'))
info_file.write(' \r\n')
# note: data_train_prep / target_train_prep are local to train_UNET and would need
# to be returned (or recomputed) for the three lines below to run
info_file.write('shape: {0}/{1}'.format(data_train_prep.shape, target_train_prep.shape) )
info_file.write('max-min {0}-{1} , {2}-{3}'.format(data_train_prep.max(),data_train_prep.min(),target_train_prep.max(),target_train_prep.min() ) )
info_file.write('dtype {0}/{1}'.format(data_train_prep.dtype , target_train_prep.dtype ) )
info_file.write(' \r\n')
info_file.write(' \r\n')
info_file.write('Augmentation {0}, Nbatches {1}'.format(AUGMENTATION, Aug_batches) )
info_file.write(' \r\n')
info_file.write('loss {0}'.format(LOSS_Func) )
info_file.write(' \r\n')
info_file.write('Data_Info')
info_file.write( str( data_info) )
info_file.write(' \r\n')
info_file.close()
# https://github.com/andreimouraviev/Mets/blob/251322d2ae9561b32ac11af829ab6535e1b53e87/UTILS.py
import numpy as np
import os
import matplotlib.pyplot as plt
import time
import pickle
from sklearn.model_selection import GroupShuffleSplit
import sklearn
import skimage
import skimage.transform as tr
import scipy.ndimage
from scipy.ndimage.interpolation import map_coordinates
from scipy.ndimage.filters import gaussian_filter
from keras import backend as K
import tensorflow as tf
import keras
from keras.layers import Input, Dense, Convolution2D, merge, Conv2D,Conv2DTranspose
from keras.layers import MaxPooling2D, UpSampling2D,Dropout, Flatten,concatenate
from keras.models import Model
from keras.utils import np_utils
from keras import backend as K
from keras.models import model_from_json
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras.regularizers import l1_l2
from keras.layers.normalization import BatchNormalization
import keras
import sklearn.metrics
from keras.models import load_model
#%% SEED
np.random.seed(7)
tf.set_random_seed(7)
print '...Seed Set'
#%% LOSS
smooth = 1.
smooth2 = 1e-5
#def confusion(TRUE_LIST,PRED_LIST):
# for i, (y_true,y_pred) in enumerate(zip(TRUE_LIST,PRED_LIST )):
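# Soft confusion-matrix terms computed on the flattened probability maps:
# FP/FN/TP/TN are raw Keras sums, while FPR/FNR/TPR/TNR/Precision are the
# corresponding rates with smooth2 = 1e-5 added to the denominators so slices with
# an empty ground-truth (or empty prediction) do not divide by zero.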
def FP(y_true,y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return K.sum( y_pred_f*(1-y_true_f) )
def FN(y_true,y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return K.sum( (1-y_pred_f)*y_true_f )
def TP(y_true,y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return K.sum( y_pred_f*y_true_f )
def TN(y_true,y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return K.sum( (1-y_pred_f)*(1-y_true_f) )
def FPR(y_true,y_pred):
#fallout
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return 1.0*(K.sum( y_pred_f*(1-y_true_f) ) )/(K.sum(1-y_true_f)+smooth2)
def FNR(y_true,y_pred):
#miss rate
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return 1.0*( K.sum( (1-y_pred_f)*y_true_f ) )/ ( K.sum( y_true_f ) +smooth2 )
def TPR(y_true,y_pred):
#sensitivity, recall, hit rate
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return 1.0*( K.sum( y_pred_f*y_true_f ) )/ ( K.sum( y_true_f ) +smooth2 )
def TNR(y_true,y_pred):
#specificity
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return 1.0*( K.sum( (1-y_pred_f)*(1-y_true_f) ) )/ ( K.sum( 1-y_true_f ) +smooth2 )
#return np.sum( (1.-y_pred_f)*y_true_f )/(1.0*np.sum( y_true_f )+smooth2)
#return 1-(np.sum( y_pred_f*y_true_f )+smooth2)/(1.0*np.sum(y_true_f)+smooth2 ) # 1 - TPR
def Precision(y_true,y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return (1.0*K.sum( y_pred_f*y_true_f ) +smooth2)/(K.sum( y_pred_f*y_true_f ) + K.sum( y_pred_f*(1.0-y_true_f) )+smooth2)
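# Tversky index: TP / (TP + alpha*FP + beta*FN). With alpha = beta = 0.5 this
# reduces to the Dice coefficient; the defaults alpha=0.3, beta=0.7 penalize missed
# lesion pixels (false negatives) more than false positives, which helps when the
# foreground class is very small. Tversky_loss is simply its negative.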
def get_Tversky(alpha = .3,beta= .7,verb=0):
def Tversky(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
G_P = alpha*K.sum( (1-y_true_f) * y_pred_f ) # G not P
P_G = beta*K.sum( y_true_f * (1-y_pred_f) ) # P not G
return (intersection + smooth )/(intersection + smooth +G_P +P_G )
def Tversky_loss(y_true, y_pred):
return -Tversky(y_true, y_pred)
return Tversky, Tversky_loss
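# Soft Dice coefficient on flattened probability maps; smooth = 1 keeps the ratio
# defined (and equal to 1) when both mask and prediction are empty. The loss is the
# negative Dice, so minimizing it maximizes overlap.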
def dice_coef(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
def dice_coef_loss(y_true, y_pred):
return -dice_coef(y_true, y_pred)
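# Class-weighted binary cross-entropy: class_W[1] scales the foreground (lesion)
# term and class_W[0] the background term; Eps guards against log(0). Note this is
# an un-normalized sum over pixels rather than a mean.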
def get_W_Bin_C_Entropy(class_W):
Eps = 1e-10
def W_Bin_C_Entropy(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
return -( class_W[1]*K.sum( y_true_f*K.log(y_pred_f+Eps ) ) + class_W[0]*K.sum( (1-y_true_f)*K.log(1-y_pred_f+Eps) ) )
return W_Bin_C_Entropy
def get_class_weights(target):
target=target/target.max()
class_frequencies = np.array([ np.sum(target == 0), np.sum(target == 1) ])
class_weights = (class_frequencies[[1,0]])**0.25
class_weights = class_weights / np.sum(class_weights) * 2.
class_weights = class_weights.astype(np.float32)
return class_weights
#%%
def model_eval(y_true,IMGS,model,S=None):
scores=[]
for (msk_slice,IMG) in zip(y_true,IMGS):
temp_msk,temp_img=np.zeros( (1,msk_slice.shape[0],msk_slice.shape[1],1) ),np.zeros( (1,msk_slice.shape[0],msk_slice.shape[1],1) )
temp_msk[0,:,:,:],temp_img[0,:,:,:]=msk_slice,IMG
scores.append( np.array( model.evaluate(temp_img, temp_msk, batch_size=1, verbose=0) ) )
mean_scores = np.mean(np.array( scores),axis=0)
std_scores = np.std(np.array( scores),axis=0)
for mean_score,std_score,metric in zip(mean_scores,std_scores,model.metrics_names):
print '{0} mean: {1} std {2}'.format(metric,mean_score,std_score)
'''
def random_guess_metrics(y_true,S=None):
scores=[]
for (msk_slice,IMG) in zip(y_true,IMGS):
temp=[]
random_guess=np.random.rand( msk_slice.shape[0], msk_slice.shape[1], msk_slice.shape[2] )
scores.append(temp )
mean_scores = np.mean(np.array( scores),axis=0)
std_scores = np.std(np.array( scores),axis=0)
for mean_score,std_score,metric in zip(mean_scores,std_scores,model.metrics_names):
print '{0} mean: {1} std {2}'.format(metric,mean_score,std_score)
'''
#%% data utils
def Load_Model(fname,extra_metrics={}):
Custom_Objects = {'dice_coef_loss': dice_coef_loss,\
'FNR':FNR,'FPR':FPR,'TPR':TPR,'TNR':TNR, 'Precision':Precision ,\
'FP':FP,'FN':FN,'TP':TP,'TN':TN}
Custom_Objects.update(extra_metrics)
model = load_model(fname,\
custom_objects=Custom_Objects)
return model
def print_info(imgs,target,lbl):
print('{0}: img / target'.format(lbl))
print('shape: {0}/{1}'.format(imgs.shape, target.shape) )
print('max-min {0}-{1} , {2}-{3}'.format(imgs.max(),imgs.min(),target.max(),target.min() ) )
print('dtype {0}/{1}'.format(imgs.dtype , target.dtype ) )
print('---------------', ' \n ')
def recomb_data(data_train,target_train,patient_train,data_valid,target_valid,patient_valid,R=1,Test_size=.08,SPLTS=1):
data = np.concatenate( [data_train, data_valid] ,axis=0)
target = np.concatenate( [target_train, target_valid] ,axis=0)
patient_ID = np.concatenate( [patient_train, patient_valid] )
return data,target,patient_ID
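# GroupShuffleSplit with groups=patient_id guarantees that all 2D slices from a
# given patient fall entirely into either the training or the validation side of a
# split, avoiding leakage between adjacent slices of the same patient.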
def get_splits_idxs(data,pid,Nsplits=5,R=1,Test_size=0.08):
GSS = GroupShuffleSplit(n_splits=Nsplits, test_size=Test_size, random_state=R)
#Train Test Split
SPLIT_IDXS =[(Train_idx,Test_idx) for Train_idx,Test_idx in GSS.split(list(data),groups=pid)]
return SPLIT_IDXS
def split_data(data,target,patient_ID,IDXS):
Train_idx,Test_idx = IDXS
data_train = data[Train_idx]
target_train = target[Train_idx]
patient_train = patient_ID[Train_idx]
data_valid = data[Test_idx]
target_valid = target[Test_idx]
patient_valid = patient_ID[Test_idx]
return data_train,target_train,patient_train,data_valid,target_valid,patient_valid
#%% data preprocessing
def prep(train_imgs,train_mask,val_imgs,val_mask):
train_imgs,train_mask,val_imgs,val_mask=train_imgs.astype(np.float32),train_mask.astype(np.float32),val_imgs.astype(np.float32),val_mask.astype(np.float32)
train_imgs,train_mask,val_imgs,val_mask=train_imgs/255,train_mask/255,val_imgs/255,val_mask/255
data_mean = np.mean(train_imgs)
data_std = np.std(train_imgs)
train_imgs,val_imgs=train_imgs-data_mean , val_imgs-data_mean
train_imgs,val_imgs=train_imgs/data_std,val_imgs/data_std
return train_imgs,train_mask.astype(np.bool ),val_imgs,val_mask.astype(np.bool )
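# normalize: binarize the masks and z-score the images using the *training* mean
# and std for both the training and validation sets, so no validation statistics
# leak into the preprocessing.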
def normalize(train_imgs,train_mask,val_imgs,val_mask):
train_imgs,val_imgs=train_imgs.astype(np.float32),val_imgs.astype(np.float32)
if not(train_mask.dtype=='bool') or not(val_mask.dtype=='bool'):
train_mask,val_mask=train_mask.astype(np.float32),val_mask.astype(np.float32)
train_mask,val_mask=train_mask/train_mask.max(),val_mask/val_mask.max()
train_mask,val_mask=train_mask.astype(np.bool),val_mask.astype(np.bool)
data_mean = np.mean(train_imgs)
data_std = np.std(train_imgs)
train_imgs,val_imgs=train_imgs-data_mean , val_imgs-data_mean
train_imgs,val_imgs=train_imgs/data_std,val_imgs/data_std
return train_imgs,train_mask.astype(np.bool ),val_imgs,val_mask.astype(np.bool )
def prep_minimal(train_imgs,train_mask,val_imgs,val_mask):
train_mask,val_mask = (train_mask/255).astype(np.bool ), (val_mask/255).astype(np.bool )
return train_imgs,train_mask,val_imgs,val_mask
def prep_int8(train_imgs,train_mask,val_imgs,val_mask):
train_imgs, val_imgs = (train_imgs/2.), (val_imgs/2.)
mean = train_imgs.mean()
train_imgs= ((train_imgs- mean ).round()).astype(np.int8)
val_imgs= ((val_imgs-mean ).round()).astype(np.int8)
train_mask,val_mask = (train_mask/255).astype(np.bool ), (val_mask/255).astype(np.bool )
return train_imgs,train_mask,val_imgs,val_mask
def prep_int8_aug(train_imgs,train_mask):
train_imgs = (train_imgs/2.)
mean = train_imgs.mean()
train_imgs= ((train_imgs- mean ).round()).astype(np.int8)
train_mask = (train_mask/255).astype(np.bool )
return train_imgs,train_mask
def shuffle(data,target):
data_order = range(data.shape[0])
np.random.shuffle(data_order)
shuffled_imgs = np.take(data,data_order,axis=0)
shuffled_masks = np.take(target,data_order,axis=0)
return shuffled_imgs,shuffled_masks
def pkl(obj,fname):
output = open(fname+'.pkl', 'wb')
pickle.dump(obj,output)
output.close()
#%% plot utils
def plot_TLoss(hist):
loss = hist['loss']
plt.plot(loss)
plt.xlabel('epochs')
plt.ylabel('loss (-dice coef)')
plt.title('training')
def plt_pts(lst, xlabel='epochs',ylabel='loss (-dice coef)',title='pre-training'):
plt.plot(lst)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
def plot_loss(hist):
val_loss = hist['val_loss']
loss = hist['loss']
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
ax1.plot(loss)
ax1.set_xlabel('epochs')
ax1.set_ylabel('loss (-dice coef)')
ax1.set_title('training')
ax2.plot(val_loss)
ax2.set_title('validation')
ax2.set_xlabel('epochs')
ax2.set_ylabel('loss (-dice coef)')
plt.show()
return
def overlay(im,seg,fs=(40,40)):
mask = seg>0
masked = np.ma.masked_where(mask == 0, mask)
plt.figure(figsize=fs)
plt.subplot(1,2,1)
plt.imshow(im, 'gray', interpolation='none')
plt.subplot(1,2,2)
plt.imshow(im, 'gray', interpolation='none')
plt.imshow(masked, 'jet', interpolation='none', alpha=0.5)
plt.show()
def plot_some_data(shuffled_imgs,shuffled_masks,n=3,Figsize=(20,10),Sride=16):
print('plot train data')
stride=Sride
offset = 1
pcounter=1
for i in range(offset, offset+n*stride, stride):
plt.figure(figsize=Figsize)
# display image
ax = plt.subplot(2, n, pcounter)
pcounter += 1
im = np.squeeze(shuffled_imgs[i+offset,:,:,:])
plt.imshow(im,cmap='gray')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#plt.show()
# display predicted
ax = plt.subplot(2, n, pcounter)
pcounter += 1
im = np.squeeze(shuffled_masks[i+offset,:,:,:])
plt.imshow(im,cmap='gray')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
def plot_results(shuffled_imgs_val,shuffled_masks_val,img_msk_prediction,n=7,Figsize=(50,25)):
stride=5
offset = 2
pcounter=1
for i in range(offset, offset+n*stride, stride):
plt.figure(figsize=Figsize)
# display image
ax = plt.subplot(2, n, pcounter)
pcounter += 1
im = np.squeeze(shuffled_imgs_val[i+offset,:,:,:])
plt.imshow(im,cmap='gray')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
#plt.show()
# display predicted
ax = plt.subplot(2, n, pcounter)
pcounter += 1
im = np.squeeze(img_msk_prediction[i+offset,:,:,:])
plt.imshow(im,cmap='gray')
plt.contour(shuffled_masks_val[i+offset,:,:,0],colors='r')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
#%%
def plot_some_Results(data,ground_truth,target,n=5,Figsize=(20,40),Sride=5,P=0):
print('plot train data')
stride=Sride
offset = 1
pcounter=1
for i in range(offset, offset+n*stride, stride):
#plt.figure(figsize=Figsize)
# display image
#ax = plt.subplot(2, n, pcounter)
pcounter += 1
im = np.squeeze(data[i+offset,:,:,:])
grd = np.squeeze(ground_truth[i+offset,:,:,:])
seg = np.squeeze(target[i+offset,:,:,:])
if P:
Overlay_Cont(im,grd,seg,figs=Figsize)
else:
show_seg(im,grd,seg,figs=Figsize)
def sample_plot_results( data,ground_truth,target,n=5,Figsize=(20,40),P=0):
N = data.shape[0]
samples = np.random.randint(0,high=N,size=n)
for s in samples:
im = np.squeeze(data[s,:,:,:])
grd = np.squeeze(ground_truth[s,:,:,:])
seg = np.squeeze(target[s,:,:,:])
if P:
Overlay_Cont(im,grd,seg,figs=Figsize)
else:
show_seg(im,grd,seg,figs=Figsize)
def sample_results_( data,ground_truth,target,n=5,Figsize=(20,40)):
N = data.shape[0]
samples = np.random.randint(0,high=N,size=n)
for s in samples:
im = np.squeeze(data[s,:,:,:])
grd = np.squeeze(ground_truth[s,:,:,:])
seg = np.squeeze(target[s,:,:,:])
compare_predWcontour(im,grd,seg,figs=Figsize)
def sample_plot(data,target,n=10,Figs=10):
N = data.shape[0]
samples = np.random.randint(0,high=N,size=n)
for s in samples:
print 'slice {0}'.format(s)
Overlay(data[s,:,:,0],target[s,:,:,0],figs=(Figs,Figs) )
def Overlay(im,seg,figs):
mask = seg>0
masked = np.ma.masked_where(mask == 0, mask)
plt.figure(figsize=figs)
ax = plt.subplot(1,2,1)
plt.imshow(im, 'viridis', interpolation='none')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(1,2,2)
plt.imshow(im, 'gray', interpolation='none')
plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
def Overlay_Cont(im,grd,seg,figs,thresh=0.2):
mask = seg>thresh
masked = np.ma.masked_where(mask == 0, mask)
plt.figure(figsize=figs)
ax = plt.subplot(1,2,1)
plt.imshow(im, 'gray', interpolation='none')
plt.contour(grd,colors='r')
ax.set_xlabel('ground truth')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(1,2,2)
plt.imshow(im, 'gray', interpolation='none')
plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
ax.set_xlabel ('pred')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
def show_seg(im,grd,seg,figs,thresh=0.2):
mask = seg>thresh
masked = np.ma.masked_where(mask == 0, mask)
plt.figure(figsize=(figs[0],3*figs[0]) )
ax = plt.subplot(1,2,1)
plt.imshow(im, 'gray', interpolation='none')
plt.contour(grd,colors='r')
ax.set_xlabel('ground truth')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(1,2,2)
plt.imshow(seg, 'jet', interpolation='none')
ax.set_xlabel ('pred')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
#ax = plt.subplot(1,3,3)
plt.figure(figsize=(figs[0]/3,figs[0]/3) )
plt.imshow(im, 'gray', interpolation='none')
plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
plt.xlabel ('pred')
#plt.xaxis().set_visible(False)
#plt.yaxis().set_visible(False)
plt.show()
def compare_predWcontour(im,grd,seg,figs,thresh=0.2):
mask = seg>thresh
masked = np.ma.masked_where(mask == 0, mask)
plt.figure(figsize=(figs[0],2*figs[0]) )
ax = plt.subplot(1,2,1)
plt.imshow(im, 'gray', interpolation='none')
plt.imshow(masked, 'autumn', interpolation='none', alpha=0.2)
plt.contour(grd,colors='r')
ax.set_xlabel('ground truth (contour)')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(1,2,2)
plt.imshow(seg, 'viridis', interpolation='none')
ax.set_xlabel ('pred')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
def plot2(img1,img2,figs=(30,60)):
plt.figure(figsize=figs)
ax = plt.subplot(1,2,1)
plt.imshow(img1, 'gray', interpolation='none')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(1,2,2)
plt.imshow(img2, 'gray', interpolation='none')
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
def plotData2(data1,data2,n=5,Figsize=(30,60),Sride=5):
print('plot train data')
stride=Sride
offset = 0
pcounter=1
for i in range(offset, offset+n*stride, stride):
pcounter += 1
img1 = np.squeeze(data1[i+offset,:,:,:])
img2 = np.squeeze(data2[i+offset,:,:,:])
plot2(img1,img2,figs=Figsize)
def plot_some_Data(data,target,n=5,Figsize=(20,40),Sride=5):
print('plot train data')
stride=Sride
offset = 1
pcounter=1
for i in range(offset, offset+n*stride, stride):
pcounter += 1
im = np.squeeze(data[i+offset,:,:,:])
seg = np.squeeze(target[i+offset,:,:,:])
Overlay(im,seg,figs=Figsize)
#%%
def get_Test_Data(Data_Path):
os.chdir(Data_Path)
data = np.load('data_test.npy')
target = np.load('target_test.npy')
patient_id = np.load('patient_test.npy')
print_info(data,target,'all')
_,img_rows,img_cols,_ = data.shape
class_freq = {0:1,1: np.sum(target/255.)/(data.shape[0]*data.shape[1]*data.shape[2] ) }
print 'class_freq {0}'.format( class_freq )
return data,target,patient_id,img_rows,img_cols,class_freq
def get_Data(Data_Path):
os.chdir(Data_Path)
datainfofile = open('DataDescription.txt','r')
data_info = datainfofile.read()
datainfofile.close()
data = np.load('data_train.npy')
target = np.load('target_train.npy')
patient_id = np.load('patient_train.npy')
print_info(data,target,'all')
_,img_rows,img_cols,_ = data.shape
class_freq = {0:1,1: np.sum(target/255.)/(data.shape[0]*data.shape[1]*data.shape[2] ) }
print 'class_freq {0}'.format( class_freq )
return data,target,patient_id,img_rows,img_cols,class_freq,data_info
def get_test_Data(Data_Path):
os.chdir(Data_Path)
datainfofile = open('DataDescription.txt','r')
data_info = datainfofile.read()
datainfofile.close()
data = np.load('data_test.npy')
target = np.load('target_test.npy')
patient_id = np.load('patient_test.npy')
print_info(data,target,'all')
_,img_rows,img_cols,_ = data.shape
class_freq = {0:1,1: np.sum(target/255.)/(data.shape[0]*data.shape[1]*data.shape[2] ) }
print 'class_freq {0}'.format( class_freq )
return data,target,patient_id,img_rows,img_cols,class_freq
def get_data(Data_Path):
os.chdir(Data_Path)
datainfofile = open('DataDescription.txt','r')
data_info = datainfofile.read()
datainfofile.close()
data_train = np.load('data_train.npy')
target_train = np.load('target_train.npy')
patient_train = np.load('patient_train.npy')
#print_info(data_train,target_train,'train')
print('-----------')
data_valid = np.load('data_val.npy')
target_valid = np.load('target_val.npy')
patient_valid = np.load('patient_val.npy')
#print_info(data_valid,target_valid,'validation')
print('-----------')
data,target,patient_id = recomb_data(data_train,target_train,patient_train,data_valid,target_valid,patient_valid)
print_info(data,target,'all')
_,img_rows,img_cols,_ = data.shape
class_freq = {0:1,1: np.sum(target/255.)/(data.shape[0]*data.shape[1]*data.shape[2] ) }
print 'class_freq {0}'.format( class_freq )
return data,target,patient_id,img_rows,img_cols,class_freq
#%%
def Get_BRATS(path_brats=''):
if not(path_brats):
path_brats = '/home/amourav/Brain Data/Processed for UNET/Brats 512/'
LGG_data = np.load(path_brats+'LGG_data.npy')
LGG_target = np.load(path_brats+'LGG_target.npy')
HGG_data = np.load(path_brats+'HGG_data.npy')
HGG_target = np.load(path_brats+'HGG_target.npy')
return np.vstack( (HGG_data,LGG_data) ) , np.vstack( (HGG_target,LGG_target) )
#%%
def CropArr(np_im,shape):
""" Takes an image and pads or crops to fit standard volume
-----------
np_im: ndarr
shape: [nslices,nrows,ncols]
Returns
-----------
np_padded : new numpy array
"""
np_padded = np.zeros(shape, dtype = np_im.dtype)
old_shape = np_im.shape
pad_llims = np.zeros([3],dtype=np.uint16)
pad_ulims = np.zeros([3],dtype=np.uint16)
old_llims = np.zeros([3],dtype=np.uint16)
old_ulims = np.zeros([3],dtype=np.uint16)
D = len(np_im.shape)
for i in range(D):
if shape[i] < old_shape[i]: # need to crop input image
pad_llims[i] = 0
pad_ulims[i] = shape[i]
crop = int((old_shape[i]-shape[i])/2)
old_llims[i]=crop
old_ulims[i]=crop + shape[i]
elif shape[i] == old_shape[i]: # same size along this axis; copy as-is
pad_llims[i] = 0
pad_ulims[i] = shape[i]
old_llims[i]=0
old_ulims[i]=shape[i]
else:
old_llims[i] = 0
old_ulims[i] = old_shape[i]
pad = int((shape[i]-old_shape[i])/2)
pad_llims[i]=pad
pad_ulims[i]= pad + old_shape[i]
np_padded[pad_llims[0]:pad_ulims[0],
pad_llims[1]:pad_ulims[1]] = \
np_im[old_llims[0]:old_ulims[0],
old_llims[1]:old_ulims[1]]
return np_padded
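# augment applies a random 2D augmentation chain to an image/mask pair:
# horizontal (and, in 'liberal' mode, vertical) flips, translation, rotation, zoom
# followed by CropArr back to the original shape, shear, gamma adjustment and, with
# the 'liberal' settings, an elastic deformation (Gaussian-smoothed random
# displacement field scaled by alpha, applied via map_coordinates). Masks use
# nearest-neighbour interpolation (or thresholding after warping) so they stay binary.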
def augment(img,msk,Type, random_state=None,verbose=False):
# RANGE FOR PARAMETERS OF DATA AUGMENTATION
if Type=='conservative':
FlipH = np.random.randint(2,size=1)
FlipV = np.random.randint(1,size=1)
SHFT_row, SHFT_col = np.random.randint(-60,10), np.random.randint(-60,60)
theta = np.random.randint(-10,10)
zoom = .9 +0.2*np.random.rand()
shear_angle = 0#-20 + 40*np.random.rand()
gam = .8 +0.4*np.random.rand()
alpha, sigma = 0,0#720,24 # Elastic Deformation params
elif Type=='liberal':
FlipH = np.random.randint(2,size=1)
FlipV = np.random.randint(2,size=1)
SHFT_row, SHFT_col = np.random.randint(-90,10), np.random.randint(-80,80)
theta = np.random.randint(-20,20)
zoom = .9 +0.2*np.random.rand()
shear_angle = -20 + 40*np.random.rand()
gam = .8 +0.4*np.random.rand()
alpha, sigma = 720,24 # Elastic Deformation params
if random_state is None:
random_state = np.random.RandomState(None)
stdr_size = img.shape
if verbose:
print 'shift: ', SHFT_row, SHFT_col, '\r\n', 'theta {0}'.format(theta), '\r\n' ,'zoom {0}'.format(zoom), \
'\r\n', 'shear {0}'.format(shear_angle), '\r\n','gamma {0}'.format( gam), \
'\r\n','alh {0}, sig {1}'.format(alpha, sigma)
## APPLY TRANSFORMATIONS
#Flip H
if FlipH:
img_FlipH=img[:,::-1]
msk_FlipH=msk[:,::-1]
elif FlipH==0:
img_FlipH=img
msk_FlipH=msk
#Flip V
if FlipV:
img_FlipV=img_FlipH[::-1,:]
msk_FlipV=msk_FlipH[::-1,:]
elif FlipV==0:
img_FlipV=img_FlipH
msk_FlipV=msk_FlipH
#Shift
img_shft = scipy.ndimage.interpolation.shift( img_FlipV,[SHFT_row, SHFT_col] )
msk_shft = scipy.ndimage.interpolation.shift( msk_FlipV,[SHFT_row, SHFT_col] )
#Rotate
img_rot = scipy.ndimage.interpolation.rotate( img_shft,theta ,reshape =False)
msk_rot = scipy.ndimage.interpolation.rotate( msk_shft,theta ,reshape =False,order =0)
#zoom
img_zoom = scipy.ndimage.interpolation.zoom( img_rot , (zoom,zoom) )
img_C = CropArr(img_zoom , stdr_size)
msk_zoom = scipy.ndimage.interpolation.zoom( msk_rot , (zoom,zoom) ,mode='nearest',order=0)
msk_C = CropArr(msk_zoom , stdr_size)
#Shear
shftX = stdr_size[0]*np.sin( shear_angle*np.pi/180)/2
shear_tf = tr.AffineTransform(translation=[shftX,0] ,shear = shear_angle*np.pi/180 )
img_shear = (255*tr.warp(img_C, shear_tf) ).astype(np.uint8)
msk_shear = (255*(tr.warp(msk_C, shear_tf , order=3)>.5) ).astype(np.uint8)
#gamma /
img_G = (255*( (img_shear/255.)**gam ) ).astype(np.uint8)
#Elastic Deformation
dx = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma, mode="constant", cval=0) * alpha
dy = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma, mode="constant", cval=0) * alpha
x, y = np.meshgrid(np.arange(stdr_size[0]), np.arange(stdr_size[1]), indexing='ij')
indices = np.reshape(x+dx, (-1, 1)), np.reshape(y+dy, (-1, 1))
img_e = map_coordinates(img_G, indices, order=1).reshape(stdr_size)
msk_e = map_coordinates(msk_shear, indices, order=0).reshape(stdr_size)
return img_e, msk_e
def aug_liberal(img,msk, random_state=None,verbose=False):
# RANGE FOR PARAMETERS OF DATA AUGMENTATION
FlipH = np.random.randint(2,size=1)
FlipV = np.random.randint(2,size=1)
SHFT_row, SHFT_col = np.random.randint(-90,10), np.random.randint(-80,80)
theta = np.random.randint(-20,20)
zoom = .9 +0.2*np.random.rand()
shear_angle = -20 + 40*np.random.rand()
gam = .8 +0.4*np.random.rand()
alpha, sigma = 720,24 # Elastic Deformation params
if random_state is None:
random_state = np.random.RandomState(None)
stdr_size = img.shape
if verbose:
print 'shift: ', SHFT_row, SHFT_col, '\r\n', 'theta {0}'.format(theta), '\r\n' ,'zoom {0}'.format(zoom), \
'\r\n', 'shear {0}'.format(shear_angle), '\r\n','gamma {0}'.format( gam), \
'\r\n','alh {0}, sig {1}'.format(alpha, sigma)
## APPLY TRANSFORMATIONS
#Flip H
if FlipH:
img_FlipH=img[:,::-1]
msk_FlipH=msk[:,::-1]
elif FlipH==0:
img_FlipH=img
msk_FlipH=msk
#Flip V
if FlipV:
img_FlipV=img_FlipH[::-1,:]
msk_FlipV=msk_FlipH[::-1,:]
elif FlipV==0:
img_FlipV=img_FlipH
msk_FlipV=msk_FlipH
#Shift
img_shft = scipy.ndimage.interpolation.shift( img_FlipV,[SHFT_row, SHFT_col] )
msk_shft = scipy.ndimage.interpolation.shift( msk_FlipV,[SHFT_row, SHFT_col] )
#Rotate
img_rot = scipy.ndimage.interpolation.rotate( img_shft,theta ,reshape =False)
msk_rot = scipy.ndimage.interpolation.rotate( msk_shft,theta ,reshape =False,order =0)
#zoom
img_zoom = scipy.ndimage.interpolation.zoom( img_rot , (zoom,zoom) )
img_C = CropArr(img_zoom , stdr_size)
msk_zoom = scipy.ndimage.interpolation.zoom( msk_rot , (zoom,zoom) ,mode='nearest',order=0)
msk_C = CropArr(msk_zoom , stdr_size)
#Shear
shftX = stdr_size[0]*np.sin( shear_angle*np.pi/180)/2
shear_tf = tr.AffineTransform(translation=[shftX,0] ,shear = shear_angle*np.pi/180 )
img_shear = (255*tr.warp(img_C, shear_tf) ).astype(np.uint8)
msk_shear = (255*(tr.warp(msk_C, shear_tf , order=3)>.5) ).astype(np.uint8)
#gamma /
img_G = (255*( (img_shear/255.)**gam ) ).astype(np.uint8)
#Elastic Deformation
dx = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma, mode="constant", cval=0) * alpha
dy = gaussian_filter((random_state.rand(*stdr_size) * 2 - 1), sigma, mode="constant", cval=0) * alpha
x, y = np.meshgrid(np.arange(stdr_size[0]), np.arange(stdr_size[1]), indexing='ij')
indices = np.reshape(x+dx, (-1, 1)), np.reshape(y+dy, (-1, 1))
img_e = map_coordinates(img_G, indices, order=1).reshape(stdr_size)
msk_e = map_coordinates(msk_shear, indices, order=0).reshape(stdr_size)
return img_e, msk_e
def make_AUG(AUG_DATA,Type,data,target,patient_id,N_BATCHES=3):
print N_BATCHES
os.mkdir(AUG_DATA)
print 'Generating Augmentations...'
np.save(AUG_DATA+os.sep+'aug_patientID.npy', patient_id)
for j in range(N_BATCHES):
data_aug, target_aug = np.zeros_like(data), np.zeros_like(target)
print 'batch {0} / {1}'.format(j+1,N_BATCHES)
for i,(IMG,SEG) in enumerate(zip(data,target)):
img_aug, msk_aug = augment(IMG[:,:,0],SEG[:,:,0],Type)
data_aug[i,:,:,0], target_aug[i,:,:,0] = img_aug, msk_aug
if i%300==0: print '{0}/{1}'.format(i,data.shape[0])
np.save(AUG_DATA+os.sep+'aug_data_{1}_{0}.npy'.format(j,Type),data_aug )
np.save(AUG_DATA+os.sep+'aug_target_{1}_{0}.npy'.format(j,Type),target_aug )
# np.save(AUG_DATA+os.sep+'aug_patientID.npy', patient_id)
def get_AUG_data(AUG_DATA,N_batches,Type):
AUGX_list, AUGY_list = [],[]
for j in range(N_batches):
data_aug=np.load(AUG_DATA+os.sep+'aug_data_{1}_{0}.npy'.format(j,Type), )
target_aug=np.load(AUG_DATA+os.sep+'aug_target_{1}_{0}.npy'.format(j,Type) )
AUGX_list.append(data_aug)
AUGY_list.append(target_aug)
patient_id_aug=np.load(AUG_DATA+os.sep+'aug_patientID.npy')
return AUGX_list,AUGY_list,patient_id_aug
def get_AUG(AUG_DATA,data,target,patient_id,Type,N_batches=3):
if os.path.isdir(AUG_DATA):
AUGX_list,AUGY_list,patient_id_aug=get_AUG_data(AUG_DATA,N_batches,Type)
else:
make_AUG(AUG_DATA,Type,data,target,patient_id,N_BATCHES=N_batches)
AUGX_list,AUGY_list,patient_id_aug = get_AUG_data(AUG_DATA,N_batches,Type)
return AUGX_list,AUGY_list,patient_id_aug
def combine_augdata(data,target,patient_id,AUGX_list,AUGY_list,patient_id_aug):
data_plus,target_plus,patient_id_plus = data,target,patient_id
for i,(augX,augY) in enumerate(zip(AUGX_list,AUGY_list)):
data_plus,target_plus=np.vstack( (data_plus,augX) ), np.vstack( (target_plus,augY) )
patient_id_plus=np.concatenate( (patient_id_plus,patient_id_aug) )
return data_plus,target_plus,patient_id_plus
def AUGMENT_DATA(AUG_DATA,TYPE,data,target,patient_id,N_Batches=1):
AUGX_list,AUGY_list,patient_id_aug=get_AUG(AUG_DATA,data,target,patient_id,TYPE,N_batches=N_Batches)
data_plus,target_plus,patient_id_plus = combine_augdata(data,target,patient_id,AUGX_list,AUGY_list,patient_id_aug)
return data_plus,target_plus ,patient_id_plus
def shw(im,Cmap='gray',Figs=(5,5)):
plt.figure(figsize = Figs)
plt.imshow(im,cmap=Cmap)
plt.show()
def set_gpu_limit(lim):
#os.environ["TENSORFLOW_FLAGS"] = "device=gpu6"
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = lim
set_session(tf.Session(config=config))
def write_eval(IMGS,MSKS,model,i,patients):
print '\n validation scores: \n'
scores = model.evaluate(IMGS,MSKS, batch_size=1, verbose=0)
txtfile = open('model_eval{0}.txt'.format(i),'w')
txtfile.write('MODEL EVAL - Patients {0} \r\n'.format(patients))
for score,metric in zip(scores,model.metrics_names):
print '{0} score: {1}'.format(metric,score)
txtfile.write('{0} score: {1} \r\n'.format(metric,score))
txtfile.close()
def save_unet(model,i):
i=str(i)
pkl(model.history.history,'trainning_hist_{0}'.format(i) )
weights = model.get_weights()
np.save('model_weights_{0}.npy'.format(i),weights)
json_string = model.to_json()
text_file = open("model_json_{0}.txt".format(i), "w")
text_file.write(json_string)
text_file.close()
history = model.history.history
history_keys = history.keys()
test_file2 = open("model_hist_{0}.txt".format(i), "w")
for key in history_keys:
test_file2.write('{0}'.format(history[key]) )
test_file2.write('\n' )
test_file2.write('\n' )
test_file2.close()
def avg_loss(model_hist):
val_loss_avg = []
val_loss_std = []
for hist in model_hist:
temp = hist['val_loss']
val_avg,val_std = np.mean(temp[-10:-1] ) , np.std(temp[-10:-1] )
print val_avg
val_loss_avg.append( val_avg )
val_loss_std.append( val_std )
print '\r\n final avg {0} - std {1} - se {2}'.format( np.mean(val_loss_avg), np.std(val_loss_avg ), np.std(val_loss_avg )/np.sqrt( len(model_hist) ) )
class CUSTOMCALL2(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.losses = []
self.val_losses = []
self.counter1 = 0
def on_epoch_end(self, epoch, logs={}):
self.losses.append(logs.get('loss'))
self.counter1 += 1
print 'epoch {0}'.format(self.counter1) ,'loss {0}'.format( logs.get('loss') )