Created February 3, 2017 18:35
-
-
Save galtay/4565f0c100adca913fe2570f821e4331 to your computer and use it in GitHub Desktop.
unet implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _unet_conv_pair(x, n_out, prefix, activation):
    """Apply two 3x3 convolutions named '<prefix>_conv1' and '<prefix>_conv2'.

    No ``border_mode`` is passed, so Keras's default ('valid') applies and each
    convolution shrinks the spatial size by 2 pixels.
    """
    x = Convolution2D(n_out, 3, 3, activation=activation, name=prefix + '_conv1')(x)
    return Convolution2D(n_out, 3, 3, activation=activation, name=prefix + '_conv2')(x)


def _unet_up_block(x, skip, n_out, prefix, batch_size, activation):
    """Run one expanding-path block of the U-Net.

    Upsamples ``x`` by a factor of 2 with a 2x2, stride-2 deconvolution. It then
    centre-crops the skip tensor ``skip`` to the upsampled size and concatenates
    the two along the channel axis (axis 3). Finally it applies two 3x3
    convolutions. The layers are named '<prefix>_deconv', '<prefix>_crop',
    '<prefix>_concat', '<prefix>_conv1' and '<prefix>_conv2'.
    """
    outpix = x.get_shape()[1].value * 2
    diff = skip.get_shape()[1].value - outpix
    # Split the crop so that an odd size difference still yields exactly
    # ``outpix`` pixels. A symmetric ``diff // 2`` crop on both sides would
    # leave the skip tensor one pixel too large and break the concatenation.
    crop_lo = diff // 2
    crop_hi = diff - crop_lo
    up = Deconvolution2D(
        n_out, 2, 2, output_shape=(batch_size, outpix, outpix, n_out),
        subsample=(2, 2), activation=activation, name=prefix + '_deconv')(x)
    cropped = Cropping2D(
        cropping=((crop_lo, crop_hi), (crop_lo, crop_hi)),
        name=prefix + '_crop')(skip)
    concat = merge([cropped, up], mode='concat', concat_axis=3,
                   name=prefix + '_concat')
    return _unet_conv_pair(concat, n_out, prefix, activation)


def unet_model(batch_size, npix_in, n_channels, n_filters, n_classes, activation='relu'):
    """Build a U-Net segmentation model with the Keras 1 functional API.

    The network has four contracting blocks, a bottleneck and four expanding
    blocks. The filter count doubles at each level, from ``n_filters`` up to
    ``n_filters * 16`` at the bottleneck. All 3x3 convolutions use Keras's
    default 'valid' border mode, so the output is spatially smaller than the
    input. For example, ``npix_in=572`` gives a 388x388 output map. Skip
    connections are centre-cropped to match the upsampled tensors.

    Args:
        batch_size: Fixed batch size baked into the input shape. It is
            required because each ``Deconvolution2D`` needs a complete
            ``output_shape``.
        npix_in: Height and width of the square input images.
        n_channels: Number of input channels (last axis, channels-last).
        n_filters: Number of filters in the first contracting block.
        n_classes: Number of output channels produced by the final 1x1
            convolution.
        activation: Activation for every 3x3 convolution and deconvolution.
            It is not applied to the 1x1 logits layer.

    Returns:
        A Keras ``Model`` mapping inputs of shape
        ``(batch_size, npix_in, npix_in, n_channels)`` to sigmoid outputs of
        shape ``(batch_size, npix_out * npix_out, n_classes)``.

    Note:
        Static sizes are read with ``.get_shape()[1].value``, which is the
        TensorFlow ``TensorShape`` API. The function therefore presumably
        requires the TensorFlow backend.
    """
    input_layer = Input(batch_shape=(batch_size, npix_in, npix_in, n_channels), name='input')

    # Contracting path: four blocks of two convolutions, each followed by 2x2
    # max pooling. The pre-pooling tensors are kept as skip connections.
    skips = []
    x = input_layer
    for level in range(1, 5):
        prefix = 'dblk%d' % level
        x = _unet_conv_pair(x, n_filters * 2 ** (level - 1), prefix, activation)
        if level == 4:
            # Dropout is applied only on the deepest block. Its output is also
            # used as that block's skip connection.
            x = Dropout(0.5, name=prefix + '_drop')(x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2), name=prefix + '_pool')(x)

    # Bottleneck.
    x = _unet_conv_pair(x, n_filters * 16, 'bottom', activation)
    x = Dropout(0.5, name='bottom_drop')(x)

    # Expanding path: mirrors the contracting path, deepest level first.
    for level in range(4, 0, -1):
        x = _unet_up_block(x, skips[level - 1], n_filters * 2 ** (level - 1),
                           'ublk%d' % level, batch_size, activation)

    # A 1x1 convolution produces per-class scores, which are then flattened to
    # (pixels, classes).
    logits = Convolution2D(n_classes, 1, 1, name='logits')(x)
    outpix = logits.get_shape()[1].value
    flat = Reshape((outpix * outpix, n_classes))(logits)
    # NOTE(review): this is an independent per-class sigmoid, not a softmax
    # over classes. It is kept as-is because callers' loss functions may
    # depend on it.
    output = Activation('sigmoid', name='sigmoid')(flat)
    return Model(input=input_layer, output=output)
Author galtay commented on Feb 3, 2017.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.