Siamese version of the segmentation networks (with various backbone encoders, ImageNet weights included) from https://github.com/qubvel/segmentation_models (updated to allow a custom trained encoder).
# Train a NN on some specific task (perhaps looking at some specific type of data).
# It starts with ImageNet weights, which are then adjusted to whatever your task needs.
# Later we can load these weights into the SiameseUnet model as a specific encoder (instead of using the ImageNet defaults).
import matplotlib, os, glob, fnmatch
if not ('DISPLAY' in os.environ):
    matplotlib.use("Agg")  # when running over ssh on a server

# FORCE CPU
#import os
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
#os.environ["CUDA_VISIBLE_DEVICES"] = ""

# from sklearn.model_selection import train_test_split  # ... can be useful

from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score
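
# --- Hedged sketch (not part of the original gist): one possible load_data() implementation. ---
# The file names, the .npy format and the 80/20 split are assumptions; adapt them to however
# your dataset is actually stored. The "..." passed to load_data() below is just a placeholder
# argument and is ignored here.
import numpy as np
from keras.utils import to_categorical

def load_data(*_ignored):
    data_dir = "data"                                    # hypothetical folder with preprocessed arrays
    x = np.load(os.path.join(data_dir, "images.npy"))    # hypothetical file: (N, H, W, 3) images
    y = np.load(os.path.join(data_dir, "masks.npy"))     # hypothetical file: (N, H, W) class indices
    y = to_categorical(y, num_classes=3)                 # one-hot, as the 3-class softmax head expects
    split = int(0.8 * len(x))                            # simple 80/20 train/val split
    return x[:split], x[split:], y[:split], y[split:]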
x_train, x_val, y_train, y_val = load_data( ... )  # magic!
# keep in mind what the y must be (for example: one-hot for softmax)
print("x_train:", len(x_train), x_train.shape)
print("x_val:", len(x_val), x_val.shape)
print("y_train:", len(y_train), y_train.shape)
print("y_val:", len(y_val), y_val.shape)

# MODEL
BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet', classes=3, activation='softmax')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=8,
    epochs=100,
    validation_data=(x_val, y_val),
)
#print(history.history)

model.save("UNet-Resnet34_trained_on_our_specific_data_.h5")
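
# --- Hedged sketch (not part of the original gist): save the training curves to a file. ---
# matplotlib was imported above (with the Agg backend when running over ssh); the exact keys
# in history.history depend on the Keras version, so we simply plot whatever was recorded.
import matplotlib.pyplot as plt
for metric_name, values in history.history.items():
    plt.plot(values, label=metric_name)
plt.xlabel("epoch")
plt.legend()
plt.savefig("training_history.png")
plt.close()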
# [ Siamese Segmentation models ]
#
# Altered code from:
#   https://github.com/qubvel/segmentation_models
# more specifically combined from these files:
#   - https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/unet/builder.py
#   - https://github.com/qubvel/segmentation_models/blob/master/segmentation_models/unet/blocks.py
# under commit https://github.com/qubvel/segmentation_models/commit/9c68d81d66e4fb856770a87b450a43bb2ae6ddba

from keras.layers import Conv2D
from keras.layers import Activation
from keras.models import Model
from segmentation_models.utils import freeze_model
from segmentation_models.utils import legacy_support
from segmentation_models.backbones import get_backbone, get_feature_layers
from segmentation_models.unet.blocks import Transpose2D_block
from segmentation_models.utils import get_layer_number, to_tuple
from keras.layers import Concatenate
from segmentation_models.unet.blocks import UpSampling2D, handle_block_names, ConvRelu
import keras
from keras.layers import Input
from keras.models import load_model

old_args_map = {
    'freeze_encoder': 'encoder_freeze',
    'skip_connections': 'encoder_features',
    'upsample_rates': None,  # removed
    'input_tensor': None,    # removed
}
@legacy_support(old_args_map)
def SiameseUnet(backbone_name='vgg16',
                input_shape=(None, None, 3),
                classes=1,
                activation='sigmoid',
                encoder_weights='imagenet',
                encoder_freeze=False,
                encoder_features='default',
                decoder_block_type='upsampling',
                decoder_filters=(256, 128, 64, 32, 16),
                decoder_use_batchnorm=True,
                **kwargs):
    """ Unet_ is a fully convolutional neural network for image semantic segmentation

    Args:
        backbone_name: name of classification model (without last dense layers) used as feature
            extractor to build segmentation model.
        input_shape: shape of input data/image ``(H, W, C)``; in the general case you do not need
            to set ``H`` and ``W``, just pass ``(None, None, C)`` to make your model able to process
            images of any size, but ``H`` and ``W`` of input images should be divisible by factor ``32``.
        classes: a number of classes for output (output shape - ``(h, w, classes)``).
        activation: name of one of ``keras.activations`` for last model layer
            (e.g. ``sigmoid``, ``softmax``, ``linear``).
        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet),
            or a path to a saved ``.h5`` model whose encoder weights will be transplanted.
        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
        encoder_features: a list of layer numbers or names starting from top of the model.
            Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
            layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
        decoder_block_type: one of blocks with following layers structure:
            - `upsampling`: ``Upsampling2D`` -> ``Conv2D`` -> ``Conv2D``
            - `transpose`: ``Transpose2D`` -> ``Conv2D``
        decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks
        decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
            is used.

    Returns:
        ``keras.models.Model``: **Unet**

    .. _Unet:
        https://arxiv.org/pdf/1505.04597
    """
    load_weights_from = None
    if encoder_weights != "imagenet" and encoder_weights is not None:
        load_weights_from = encoder_weights
        encoder_weights = None

    backbone = get_backbone(backbone_name,
                            input_shape=input_shape,
                            input_tensor=None,
                            weights=encoder_weights,
                            include_top=False)

    if load_weights_from is not None:
        model_to_load_weights_from = load_model(load_weights_from)
        # now let's assume that this loaded model had its own "top" upsampling section trained on another task
        # let's transplant what we can, that is the backbone encoder
        output = model_to_load_weights_from.layers[len(backbone.layers) - 1].output  # keep only the encoder part (drop the decoder / final conv and activation)
        transplant = keras.models.Model(model_to_load_weights_from.input, output)
        #transplant.summary()

        transplant.save("transplant.h5")  # hacky way
        backbone.load_weights("transplant.h5")

        # Check if the weights have been loaded
        """
        inspect_i = 0
        import numpy as np
        w1 = np.asarray(transplant.get_weights()[inspect_i])
        print(w1)
        w2 = np.asarray(backbone.get_weights()[inspect_i])
        print(w2)
        """
        print("Loaded weights into", backbone_name, "from", load_weights_from)

    if encoder_features == 'default':
        encoder_features = get_feature_layers(backbone_name, n=4)

    model = build_siamese_unet(backbone,
                               classes,
                               encoder_features,
                               decoder_filters=decoder_filters,
                               block_type=decoder_block_type,
                               activation=activation,
                               n_upsample_blocks=len(decoder_filters),
                               upsample_rates=(2, 2, 2, 2, 2),
                               use_batchnorm=decoder_use_batchnorm,
                               input_shape=input_shape)

    # lock encoder weights for fine-tuning
    if encoder_freeze:
        freeze_model(backbone)

    model.name = 'u-{}'.format(backbone_name)
    return model
def Siamese_Upsample2D_block(filters, stage, kernel_size=(3, 3), upsample_rate=(2, 2),
                             use_batchnorm=False, skip_a=None, skip_b=None):
    def layer(input_tensor):
        conv_name, bn_name, relu_name, up_name = handle_block_names(stage)

        x = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor)

        if skip_a is not None and skip_b is not None:
            x = Concatenate()([x, skip_a, skip_b])  # siamese concatenation

        x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
                     conv_name=conv_name + '1', bn_name=bn_name + '1', relu_name=relu_name + '1')(x)
        x = ConvRelu(filters, kernel_size, use_batchnorm=use_batchnorm,
                     conv_name=conv_name + '2', bn_name=bn_name + '2', relu_name=relu_name + '2')(x)
        return x
    return layer
def build_siamese_unet(backbone, classes, skip_connection_layers,
                       decoder_filters=(256, 128, 64, 32, 16),
                       upsample_rates=(2, 2, 2, 2, 2),
                       n_upsample_blocks=5,
                       block_type='upsampling',
                       activation='sigmoid',
                       use_batchnorm=True,
                       input_shape=(None, None, 3)):
    print("Entered build_siamese_unet with arguments:")
    print("backbone", backbone)
    print("---\n")
    backbone.summary()
    print("---\n")
    print("classes", classes)
    print("skip_connection_layers", skip_connection_layers)
    print("decoder_filters", decoder_filters)
    print("upsample_rates", upsample_rates)
    print("n_upsample_blocks", n_upsample_blocks)
    print("block_type", block_type)
    print("activation", activation)
    print("use_batchnorm", use_batchnorm)

    input = backbone.input
    x = backbone.output

    # Prepare for multiple heads in the siamese nn:
    skip_connection_idx = ([get_layer_number(backbone, l) if isinstance(l, str) else l
                            for l in skip_connection_layers])
    print("skip_connection_idx", skip_connection_idx)

    skip_connections = []
    for idx in skip_connection_idx:
        skip_connection = backbone.layers[idx].output
        skip_connections.append(skip_connection)
    print("skip_connections layers", len(skip_connections), skip_connections)
    # 4 layers, e.g. for resnet34 with 256x256 inputs:
    #   'stage4_unit1_relu1/Relu:0' shape=(?, 16, 16, 256)
    #   'stage3_unit1_relu1/Relu:0' shape=(?, 32, 32, 128)
    #   'stage2_unit1_relu1/Relu:0' shape=(?, 64, 64, 64)
    #   'relu0/Relu:0' shape=(?, 128, 128, 64)

    siamese_backbone_model_encode = Model(inputs=[input], outputs=[x] + skip_connections)
    print("siamese_model_encode.input", siamese_backbone_model_encode.input)
    print("siamese_model_encode.output", siamese_backbone_model_encode.output)  # x and the (now 4) skip connections

    # Then merging
    input_a = Input(shape=(input_shape[0], input_shape[1], input_shape[2]))
    input_b = Input(shape=(input_shape[0], input_shape[1], input_shape[2]))
    branch_a_outputs = siamese_backbone_model_encode([input_a])
    branch_b_outputs = siamese_backbone_model_encode([input_b])
    branch_a = branch_a_outputs[0]
    branch_b = branch_b_outputs[0]

    x = Concatenate()([branch_a, branch_b])  # both inputs, in theory 8x8x512 + 8x8x512 -> 8x8x1024
    skip_connection_outputs_a = branch_a_outputs[1:]
    skip_connection_outputs_b = branch_b_outputs[1:]

    if block_type == 'transpose':
        up_block = Transpose2D_block
        assert False  # NOT IMPLEMENTED
    else:
        up_block = Siamese_Upsample2D_block

    for i in range(n_upsample_blocks):
        skip_connection_a = None
        skip_connection_b = None
        if i < len(skip_connection_idx):  # also up to len(skip_connection_outputs_a)
            skip_connection_a = skip_connection_outputs_a[i]
            skip_connection_b = skip_connection_outputs_b[i]

        upsample_rate = to_tuple(upsample_rates[i])

        x = up_block(decoder_filters[i], i, upsample_rate=upsample_rate,
                     skip_a=skip_connection_a, skip_b=skip_connection_b, use_batchnorm=use_batchnorm)(x)

    x = Conv2D(classes, (3, 3), padding='same', name='final_conv')(x)
    x = Activation(activation, name=activation)(x)

    #model = Model(input, x)
    full_model = Model(inputs=[input_a, input_b], outputs=x)
    return full_model
# There is support for all of these (with ImageNet weights included) ... qubvel/segmentation_models is awesome!
#   VGG          'vgg16' 'vgg19'
#   ResNet       'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
#   SE-ResNet    'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
#   ResNeXt      'resnext50' 'resnext101'
#   SE-ResNeXt   'seresnext50' 'seresnext101'
#   SENet154     'senet154'
#   DenseNet     'densenet121' 'densenet169' 'densenet201'
#   Inception    'inceptionv3' 'inceptionresnetv2'
#   MobileNet    'mobilenet' 'mobilenetv2'
# Performance comparison for classification: https://github.com/qubvel/classification_models

BACKBONE = 'resnet34'
custom_weights_file = 'imagenet'
custom_weights_file = "UNet-Resnet34_trained_on_our_specific_data_.h5"  # or None or "imagenet"

model = SiameseUnet(BACKBONE, encoder_weights=custom_weights_file, classes=3, activation='softmax', input_shape=(256, 256, 3))
print("model.input", model.input)
print("model.output", model.output)
model.summary()

# Ps: it is possible to adapt the other models (FPN, Linknet and PSPNet) in a similar manner.
# Ps2: some of these siamese NN models end up with a large number of parameters ...
#      if we don't have much data, we should perhaps freeze some of the encoder layers with encoder_freeze=True.
# Ps3: keras caches downloaded models in /home/<username>/.keras/models/
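
# --- Hedged sketch (not part of the original gist): compiling and fitting the siamese model. ---
# The siamese net expects two co-registered images per sample (e.g. a "before" and an "after"
# view for change detection). x_a, x_b and y below are hypothetical arrays of shapes
# (N, 256, 256, 3), (N, 256, 256, 3) and (N, 256, 256, 3 classes); plug in your own data.
#
# from segmentation_models.metrics import iou_score
# model.compile('Adam', loss='categorical_crossentropy', metrics=[iou_score])
# history = model.fit(x=[x_a, x_b], y=y, batch_size=4, epochs=50, validation_split=0.1)
# model.save("SiameseUnet-Resnet34.h5")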