Skip to content

Instantly share code, notes, and snippets.

@karolzak
Last active January 16, 2019 22:26
Show Gist options
  • Save karolzak/4d72d589d518bca7fe79ec3cd6cd7dd7 to your computer and use it in GitHub Desktop.
Save karolzak/4d72d589d518bca7fe79ec3cd6cd7dd7 to your computer and use it in GitHub Desktop.
Simple UNET implementation in Keras
###################### unet ############################
from keras.models import Model
from keras.layers import *
def upsample_conv(filters, kernel_size, strides, padding):
    """Create a learnable upsampling layer (transposed convolution).

    Args:
        filters: number of output channels of the transposed convolution.
        kernel_size: kernel size of the transposed convolution.
        strides: upsampling stride.
        padding: Keras padding mode, e.g. 'same'.

    Returns:
        An un-called ``Conv2DTranspose`` layer instance.
    """
    layer = Conv2DTranspose(
        filters,
        kernel_size=kernel_size,
        strides=strides,
        padding=padding,
    )
    return layer
def upsample_simple(filters, kernel_size, strides, padding):
    """Create a fixed (non-learnable) upsampling layer.

    The signature mirrors ``upsample_conv`` so the two are interchangeable.
    ``filters``, ``kernel_size`` and ``padding`` are accepted but not used.

    Args:
        filters: ignored.
        kernel_size: ignored.
        strides: upsampling factor, forwarded as ``UpSampling2D``'s ``size``.
        padding: ignored.

    Returns:
        An un-called ``UpSampling2D`` layer instance.
    """
    upsampling_factor = strides
    return UpSampling2D(size=upsampling_factor)
def conv2d_block(
        inputs,
        use_batch_norm=True,
        use_dropout=True,
        dropout=0.5,
        filters=16,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same'):
    """Apply two stacked 2-D convolutions to ``inputs``.

    Each convolution uses the given activation, initializer and padding, and
    is optionally followed by ``BatchNormalization``. When enabled, a single
    ``Dropout`` layer sits between the two convolutions (after the first
    convolution's batch norm); there is no dropout after the second one.

    Args:
        inputs: Keras tensor to transform.
        use_batch_norm: add BatchNormalization after each convolution.
        use_dropout: insert Dropout after the first convolution stage.
        dropout: dropout rate used when ``use_dropout`` is true.
        filters: number of filters for both convolutions.
        kernel_size: kernel size for both convolutions.
        activation: activation applied inside each Conv2D.
        kernel_initializer: weight initializer for both convolutions.
        padding: Keras padding mode for both convolutions.

    Returns:
        The output Keras tensor of the second stage.
    """
    conv_options = dict(
        activation=activation,
        kernel_initializer=kernel_initializer,
        padding=padding,
    )
    tensor = inputs
    for stage in range(2):
        tensor = Conv2D(filters, kernel_size, **conv_options)(tensor)
        if use_batch_norm:
            tensor = BatchNormalization()(tensor)
        # Dropout is only placed between the two convolutions.
        if stage == 0 and use_dropout:
            tensor = Dropout(dropout)(tensor)
    return tensor
def get_unet(
        input_shape,
        use_batch_norm=True,
        upsample_mode='DECONV',
        use_dropout=True,
        use_dropout_on_upsampling=False,
        dropout=0.1,
        dropout_change_per_layer=0.0,
        filters=16,
        nm_layers=4):
    """Build a U-Net model with ``nm_layers`` down/up-sampling levels.

    The encoder has ``nm_layers`` levels. Each level is a ``conv2d_block``
    followed by 2x2 max pooling. A bottleneck ``conv2d_block`` follows. Then
    ``nm_layers`` decoder levels each upsample, concatenate the matching
    encoder output (skip connection) and apply another ``conv2d_block``.
    A final 1x1 convolution with sigmoid activation produces a single output
    channel (presumably a binary mask; confirm against the training code).

    Args:
        input_shape: shape passed to ``Input`` (without the batch axis). The
            spatial dimensions should be divisible by ``2 ** nm_layers``.
            Otherwise the pooled-then-upsampled tensors will not match the
            skip tensors at ``concatenate``.
        use_batch_norm: add BatchNormalization after every convolution.
        upsample_mode: 'DECONV' selects ``upsample_conv`` (Conv2DTranspose).
            Any other value selects ``upsample_simple`` (UpSampling2D).
        use_dropout: enable dropout in the encoder and bottleneck blocks.
        use_dropout_on_upsampling: enable dropout in the decoder blocks.
            This is independent of ``use_dropout``.
        dropout: dropout rate of the first encoder block.
        dropout_change_per_layer: added to the dropout rate after every
            encoder level. Decoder blocks all reuse the bottleneck's rate.
        filters: filter count of the first block. It doubles at each deeper
            level and halves again on the way back up.
        nm_layers: number of encoder/decoder levels (default 4). This was
            previously ignored and the network was always 4 levels deep;
            the default still produces exactly that network.

    Returns:
        A Keras ``Model`` mapping the input tensor to the sigmoid output.

    Raises:
        ValueError: if ``nm_layers`` is negative.
    """
    if nm_layers < 0:
        raise ValueError('nm_layers must be non-negative, got %r' % (nm_layers,))

    if upsample_mode == 'DECONV':
        upsample = upsample_conv
    else:
        upsample = upsample_simple

    inputs = Input(input_shape)
    x = inputs

    # Contracting path: remember each level's block output for its skip
    # connection, then pool. Filters and dropout grow after every level.
    skips = []
    for _ in range(nm_layers):
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         use_dropout=use_dropout, dropout=dropout)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        dropout += dropout_change_per_layer
        filters = filters * 2

    # Bottleneck: no max pooling after the deepest block.
    x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                     use_dropout=use_dropout, dropout=dropout)

    # Expanding path: dropout is controlled by use_dropout_on_upsampling
    # (off by default) and reuses the bottleneck's dropout rate unchanged.
    for skip in reversed(skips):
        filters //= 2
        x = upsample(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = conv2d_block(inputs=x, filters=filters, use_batch_norm=use_batch_norm,
                         use_dropout=use_dropout_on_upsampling, dropout=dropout)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)
    model = Model(inputs=[inputs], outputs=[outputs])
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment