Created December 26, 2020 12:19
-
-
Save yuntan/8198f80593b6897844236c5a5a7b07da to your computer and use it in GitHub Desktop.
DeBlurNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.keras import Model, Sequential | |
from tensorflow.keras.layers import ( | |
Input, Add, Concatenate, Conv2D, Conv2DTranspose, BatchNormalization, ReLU, | |
Activation, | |
) | |
def conv_bn_relu(n_channels: int, kernel_size: int = 3, strides: int = 1,
                 name: str = None):
    """Build a ``Conv2D -> BatchNormalization -> ReLU`` unit.

    The convolution uses ``padding="same"``, has no bias term, and draws its
    kernel from a normal distribution with mean 0 and stddev 0.02.

    Args:
        n_channels: Number of output filters of the convolution.
        kernel_size: Spatial size of the square convolution kernel.
        strides: Convolution stride (``down`` passes 2 here).
        name: Name given to the returned ``Sequential`` (Keras auto-names it
            when ``None``).

    Returns:
        A ``tf.keras.Sequential`` containing the three layers, in that order.
    """
    conv = Conv2D(
        n_channels,
        kernel_size,
        strides,
        padding="same",
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    # Layers are instantiated in the same order as before: conv, BN, ReLU.
    stack = [conv, BatchNormalization(), ReLU()]
    return Sequential(stack, name=name)
def down(n_channels: int, name: str = None):
    """Build a downsampling unit: ``conv_bn_relu`` with stride 2.

    Args:
        n_channels: Number of output filters.
        name: Name for the resulting ``Sequential`` block (optional).

    Returns:
        The ``Sequential`` block produced by ``conv_bn_relu`` with its
        default kernel size and ``strides=2``.
    """
    return conv_bn_relu(n_channels, name=name, strides=2)
def up(n_channels: int, kernel_size: int = 4, name: str = None):
    """Return a callable that performs one decoder upsampling step.

    The returned function ``ret(x, skip)`` applies, in order:
      1. ``Conv2DTranspose`` with stride 2, ``padding="same"``, no bias and a
         normal(mean=0, stddev=0.02) kernel initializer;
      2. ``BatchNormalization``;
      3. element-wise ``Add`` with ``skip``;
      4. ``ReLU``.

    Args:
        n_channels: Output channels of the transposed convolution. Must equal
            the channel count of ``skip`` for the ``Add`` to succeed.
        kernel_size: Transposed-convolution kernel size (default 4, the
            previously hard-coded value).
        name: Optional prefix for the created layers' names, matching the
            ``name`` argument of ``conv_bn_relu``/``down``. The layers are
            named ``{name}_deconv``, ``{name}_bn``, ``{name}_add`` and
            ``{name}_relu``. With ``None`` (the default), Keras auto-names
            them exactly as before.

    Returns:
        A function ``ret(x, skip)`` that creates fresh layers each time it is
        called and returns the resulting tensor. Call it only once when
        ``name`` is set, otherwise the layer names will collide.

    NOTE(review): with stride 2 and "same" padding the transposed conv doubles
    the height/width of ``x``. ``skip`` must therefore be exactly twice as
    large, which holds only if upstream spatial sizes were even. Confirm the
    model's input sizes.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    def _layer_name(suffix):
        # None keeps Keras' automatic naming, preserving the old behaviour.
        return None if name is None else f"{name}_{suffix}"

    def ret(x, skip):
        x = Conv2DTranspose(n_channels, kernel_size=kernel_size, strides=2,
                            padding="same", kernel_initializer=initializer,
                            use_bias=False, name=_layer_name("deconv"))(x)
        x = BatchNormalization(name=_layer_name("bn"))(x)
        x = Add(name=_layer_name("add"))([x, skip])
        x = ReLU(name=_layer_name("relu"))(x)
        return x

    return ret
def last(n_channels: int = 3, kernel_size: int = 3, name: str = None):
    """Return a callable that builds the output stage (F6_2 in ``deblurnet``).

    The returned function ``ret(x, skip)`` applies, in order:
      1. ``Conv2DTranspose`` with stride 1, ``padding="same"``, no bias and a
         normal(mean=0, stddev=0.02) kernel initializer, so the spatial size
         is unchanged;
      2. ``BatchNormalization``;
      3. element-wise ``Add`` with ``skip`` (a residual connection);
      4. a sigmoid activation, so outputs lie in (0, 1).

    Args:
        n_channels: Output channels (default 3, the previously hard-coded
            value). Must equal the channel count of ``skip``.
        kernel_size: Kernel size of the final convolution (default 3, as
            before).
        name: Optional prefix for the created layers' names, matching the
            ``name`` argument of ``conv_bn_relu``/``down``. The layers are
            named ``{name}_deconv``, ``{name}_bn``, ``{name}_add`` and
            ``{name}_sigmoid``. With ``None`` (the default), Keras auto-names
            them exactly as before.

    Returns:
        A function ``ret(x, skip)`` that creates fresh layers each time it is
        called and returns the output tensor.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    def _layer_name(suffix):
        # None keeps Keras' automatic naming, preserving the old behaviour.
        return None if name is None else f"{name}_{suffix}"

    def ret(x, skip):
        x = Conv2DTranspose(n_channels, kernel_size=kernel_size, strides=1,
                            padding="same", kernel_initializer=initializer,
                            use_bias=False, name=_layer_name("deconv"))(x)
        x = BatchNormalization(name=_layer_name("bn"))(x)
        x = Add(name=_layer_name("add"))([x, skip])
        x = Activation("sigmoid", name=_layer_name("sigmoid"))(x)
        return x

    return ret
def deblurnet():
    """Build the DeBlurNet encoder-decoder as a Keras functional ``Model``.

    The model takes a 15-channel input of arbitrary height/width. It predicts
    a 3-channel image to which input channels 6:9 are added (a global
    residual) before the final sigmoid. Those channels are presumably the
    centre frame of five stacked RGB frames (TODO: confirm against the data
    pipeline).

    Layer sequence (skip connections in parentheses):
        F0 (64, 5x5)
        -> D1 (64, /2) -> F1_1, F1_2 (128)
        -> D2 (256, /2) -> F2_1..F2_3 (256)
        -> D3 (512, /2) -> F3_1..F3_3 (512)
        -> U1 (256, +F2_3) -> F4_1..F4_3 (256)
        -> U2 (128, +F1_2) -> F5_1 (128), F5_2 (64)
        -> U3 (64, +F0) -> F6_1 (15)
        -> F6_2 (3, +input[..., 6:9], sigmoid)

    Returns:
        An uncompiled ``tf.keras.Model``.

    NOTE(review): there are three stride-2 stages and each decoder step
    exactly doubles the spatial size. Unless H and W are multiples of 8, one
    of the skip ``Add`` layers will see mismatched shapes.
    """
    # For debugging with a fixed shape, use Input(shape=[128, 128, 15]).
    inputs = Input(shape=[None, None, 15])

    def _apply_convs(tensor, specs):
        # Chain stride-1 conv/BN/ReLU units given as (layer name, channels).
        for layer_name, channels in specs:
            tensor = conv_bn_relu(channels, name=layer_name)(tensor)
        return tensor

    # Decoder consumes these last-in, first-out; the first entry is the
    # residual input slice used by the output stage.
    skip_stack = [inputs[:, :, :, 6:9]]

    h = conv_bn_relu(64, kernel_size=5, name="F0")(inputs)
    skip_stack.append(h)

    # Encoder.
    h = down(64, name="D1")(h)
    h = _apply_convs(h, [("F1_1", 128), ("F1_2", 128)])
    skip_stack.append(h)

    h = down(256, name="D2")(h)
    h = _apply_convs(h, [("F2_1", 256), ("F2_2", 256), ("F2_3", 256)])
    skip_stack.append(h)

    h = down(512, name="D3")(h)
    h = _apply_convs(h, [("F3_1", 512), ("F3_2", 512), ("F3_3", 512)])

    # Decoder.
    h = up(256)(h, skip_stack.pop())  # U1
    h = _apply_convs(h, [("F4_1", 256), ("F4_2", 256), ("F4_3", 256)])

    h = up(128)(h, skip_stack.pop())  # U2
    h = _apply_convs(h, [("F5_1", 128), ("F5_2", 64)])

    h = up(64)(h, skip_stack.pop())  # U3
    h = _apply_convs(h, [("F6_1", 15)])

    h = last()(h, skip_stack.pop())  # F6_2
    return Model(inputs=inputs, outputs=h)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.