    Simple UNET implementation in Keras
  
        
  
    
    
  
  
    
###################### unet ############################
from keras.models import Model
from keras.layers import (Input, Conv2D, Conv2DTranspose, UpSampling2D,
                          MaxPooling2D, BatchNormalization, Dropout, concatenate)


def upsample_conv(filters, kernel_size, strides, padding):
    return Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding)


def upsample_simple(filters, kernel_size, strides, padding):
    return UpSampling2D(strides)


def conv2d_block(
        inputs,
        use_batch_norm=True,
        use_dropout=True,
        dropout=0.5,
        filters=16,
        kernel_size=(3, 3),
        activation='relu',
        kernel_initializer='he_normal',
        padding='same'):

    c = Conv2D(filters, kernel_size, activation=activation, kernel_initializer=kernel_initializer, padding=padding)(inputs)
    if use_batch_norm:
        c = BatchNormalization()(c)
    if use_dropout:
        c = Dropout(dropout)(c)
    c = Conv2D(filters, kernel_size, activation=activation, kernel_initializer=kernel_initializer, padding=padding)(c)
    if use_batch_norm:
        c = BatchNormalization()(c)
    return c


def get_unet(
        input_shape,
        use_batch_norm=True,
        upsample_mode='DECONV',  # 'DECONV' uses Conv2DTranspose, any other value falls back to UpSampling2D
        use_dropout=True,
        use_dropout_on_upsampling=False,
        dropout=0.1,
        dropout_change_per_layer=0.0,
        filters=16,
        nm_layers=4):  # note: the depth is currently hardcoded below, so this argument is not used

    if upsample_mode == 'DECONV':
        upsample = upsample_conv
    else:
        upsample = upsample_simple

    # Build U-Net model
    inputs = Input(input_shape)

    # This could be done iteratively, but it is left unrolled for better clarity and transparency.

    #### downsampling layers
    c1 = conv2d_block(inputs=inputs, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)
    p1 = MaxPooling2D((2, 2))(c1)

    # increase the number of filters with each layer and change dropout if required
    dropout += dropout_change_per_layer
    filters = filters * 2
    c2 = conv2d_block(inputs=p1, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)
    p2 = MaxPooling2D((2, 2))(c2)

    # increase the number of filters with each layer and change dropout if required
    dropout += dropout_change_per_layer
    filters = filters * 2
    c3 = conv2d_block(inputs=p2, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)
    p3 = MaxPooling2D((2, 2))(c3)

    # increase the number of filters with each layer and change dropout if required
    dropout += dropout_change_per_layer
    filters = filters * 2
    c4 = conv2d_block(inputs=p3, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)
    p4 = MaxPooling2D((2, 2))(c4)

    # increase the number of filters with each layer and change dropout if required
    dropout += dropout_change_per_layer
    filters = filters * 2

    # no max pooling on this last (bottleneck) layer
    c5 = conv2d_block(inputs=p4, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)

    #### upsampling layers
    use_dropout = use_dropout_on_upsampling  # by default we don't want dropout on the upsampling layers (same as in the original implementation from the research paper)

    filters //= 2  # decrease the number of filters with each layer
    u6 = upsample(filters, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = concatenate([u6, c4])
    c6 = conv2d_block(inputs=u6, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)

    filters //= 2  # decrease the number of filters with each layer
    u7 = upsample(filters, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    c7 = conv2d_block(inputs=u7, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)

    filters //= 2  # decrease the number of filters with each layer
    u8 = upsample(filters, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    c8 = conv2d_block(inputs=u8, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)

    filters //= 2  # decrease the number of filters with each layer
    u9 = upsample(filters, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1])
    c9 = conv2d_block(inputs=u9, filters=filters, use_batch_norm=use_batch_norm, use_dropout=use_dropout, dropout=dropout)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(c9)

    model = Model(inputs=[inputs], outputs=[outputs])
    return model
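
A minimal usage sketch (not part of the original gist): the input shape, optimizer, and loss below are assumptions for a binary segmentation task on 256x256 single-channel images; adjust them to your data.

# Minimal usage sketch (assumptions: 256x256 grayscale inputs, binary masks).
model = get_unet(
    input_shape=(256, 256, 1),
    use_batch_norm=True,
    upsample_mode='DECONV',  # Conv2DTranspose upsampling; any other value -> UpSampling2D
    use_dropout=True,
    dropout=0.1,
    filters=16)

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()

# A training call would look something like this (X_train / Y_train are hypothetical arrays):
# model.fit(X_train, Y_train, validation_split=0.1, batch_size=8, epochs=50)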
  