@atamborrino
Created October 11, 2018 12:02
U-net
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, MaxPooling2D,
                                     UpSampling2D, Concatenate, Dropout, BatchNormalization)

def conv_block(m, dim, acti, bn, res, do=0):
    # Two 3x3 convolutions with optional batch norm, dropout and residual concatenation.
    n = Conv2D(dim, 3, activation=acti, padding='same')(m)
    n = BatchNormalization()(n) if bn else n
    n = Dropout(do)(n) if do else n
    n = Conv2D(dim, 3, activation=acti, padding='same')(n)
    n = BatchNormalization()(n) if bn else n
    return Concatenate()([m, n]) if res else n

def level_block(m, dim, depth, inc, acti, do, bn, mp, up, res):
    # Recursively builds one encoder/decoder level of the U-Net; dropout is only
    # applied in the bottleneck block (depth == 0).
    if depth > 0:
        n = conv_block(m, dim, acti, bn, res)
        # Downsample: max pooling or strided convolution.
        m = MaxPooling2D()(n) if mp else Conv2D(dim, 3, strides=2, padding='same')(n)
        m = level_block(m, int(inc * dim), depth - 1, inc, acti, do, bn, mp, up, res)
        # Upsample: nearest-neighbour upsampling + conv, or transposed convolution.
        if up:
            m = UpSampling2D()(m)
            m = Conv2D(dim, 2, activation=acti, padding='same')(m)
        else:
            m = Conv2DTranspose(dim, 3, strides=2, activation=acti, padding='same')(m)
        # Skip connection from the matching encoder level.
        n = Concatenate()([n, m])
        m = conv_block(n, dim, acti, bn, res)
    else:
        m = conv_block(m, dim, acti, bn, res, do)
    return m

def UNet(img_shape, out_ch=1, start_ch=64, depth=4, inc_rate=2., activation='relu',
         dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    # U-Net builder: `depth` encoder/decoder levels, `start_ch` filters at the top level,
    # multiplied by `inc_rate` at each level down; a 1x1 sigmoid convolution produces the mask.
    i = Input(shape=img_shape)
    o = level_block(i, start_ch, depth, inc_rate, activation, dropout, batchnorm, maxpool, upconv, residual)
    o = Conv2D(out_ch, 1, activation='sigmoid')(o)
    return Model(inputs=i, outputs=o)

from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K


def dice_coef(y_true, y_pred):
    # Hard Dice coefficient: thresholds predictions at 0.5, so it is a metric,
    # not a differentiable loss.
    y_true_f = K.flatten(y_true)
    y_pred = K.cast(y_pred, 'float32')
    y_pred_f = K.cast(K.greater(K.flatten(y_pred), 0.5), 'float32')
    intersection = y_true_f * y_pred_f
    score = 2. * K.sum(intersection) / (K.sum(y_true_f) + K.sum(y_pred_f))
    return score

def dice_loss(y_true, y_pred):
    # Soft (differentiable) Dice loss with a smoothing term to avoid division by zero.
    smooth = 1.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = y_true_f * y_pred_f
    score = (2. * K.sum(intersection) + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1. - score


def bce_dice_loss(y_true, y_pred):
    # Combined objective: binary cross-entropy plus soft Dice loss.
    return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)

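A quick sanity check of dice_loss (an illustration, not part of the original gist; it assumes TF 2.x eager execution so the scalar values can be read off directly): identical masks give a loss of 0, disjoint masks a loss close to 1.

# Sanity check (assumes TF 2.x eager execution).
import tensorflow as tf

y_a = tf.constant([[1., 1., 0., 0.]])  # toy ground-truth mask
y_b = tf.constant([[0., 0., 1., 1.]])  # disjoint "prediction"
print(float(dice_loss(y_a, y_a)))  # 0.0 -> perfect overlap
print(float(dice_loss(y_a, y_b)))  # 0.8 -> no overlap; the smoothing term keeps it below 1.0
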
# img_size_target is the side length of the square grayscale input image; it is
# assumed to be defined earlier in the surrounding notebook/script.
model = UNet((img_size_target, img_size_target, 1), start_ch=16, depth=5, batchnorm=True)
model.compile(loss=bce_dice_loss, optimizer="adam")
model.summary()
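
A minimal training sketch (an illustration, not part of the original gist), assuming x_train, y_train, x_valid and y_valid are float32 arrays of shape (n, img_size_target, img_size_target, 1) with mask values in [0, 1]; it wires up the callbacks imported above and tracks the hard Dice metric. The checkpoint filename is arbitrary.

# Hypothetical training setup: the data arrays below are assumed to be prepared elsewhere.
model.compile(loss=bce_dice_loss, optimizer=Adam(), metrics=[dice_coef])

callbacks = [
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-5),
    ModelCheckpoint('unet_best.h5', monitor='val_loss', save_best_only=True),  # arbitrary filename
]

history = model.fit(
    x_train, y_train,
    validation_data=(x_valid, y_valid),
    batch_size=32,
    epochs=100,
    callbacks=callbacks,
)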