Skip to content

Instantly share code, notes, and snippets.

@yuntan
Created December 26, 2020 12:19
Show Gist options
  • Save yuntan/8198f80593b6897844236c5a5a7b07da to your computer and use it in GitHub Desktop.
Save yuntan/8198f80593b6897844236c5a5a7b07da to your computer and use it in GitHub Desktop.
DeBlurNet
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import (
Input, Add, Concatenate, Conv2D, Conv2DTranspose, BatchNormalization, ReLU,
Activation,
)
def conv_bn_relu(n_channels: int, kernel_size: int = 3, strides: int = 1,
                 name: str = None):
    """Build a Conv2D -> BatchNormalization -> ReLU block as a Sequential.

    The convolution uses "same" padding, no bias (presumably because the
    following BatchNormalization adds its own offset), and kernel weights
    drawn from a normal distribution with mean 0 and stddev 0.02.

    Args:
        n_channels: Number of output filters of the convolution.
        kernel_size: Side length of the square convolution kernel.
        strides: Convolution stride; with "same" padding, 2 halves the
            spatial size (rounded up).
        name: Optional name for the returned Sequential model.

    Returns:
        A ``Sequential`` model applying conv, batch norm and ReLU in order.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    # Pass ``name`` by keyword: in Keras 3 the Sequential signature is
    # (layers, trainable, name), so a positional second argument would be
    # consumed as ``trainable`` and the layer name silently dropped.
    return Sequential([
        Conv2D(n_channels, kernel_size, strides, padding="same",
               kernel_initializer=initializer, use_bias=False),
        BatchNormalization(),
        ReLU(),
    ], name=name)
def down(n_channels: int, name: str = None):
    """Return a downsampling block: 3x3 conv (stride 2) + batch norm + ReLU.

    Thin wrapper over ``conv_bn_relu``. With "same" padding the stride of 2
    halves the spatial size (rounded up).

    Args:
        n_channels: Number of output filters.
        name: Optional name for the returned block.
    """
    return conv_bn_relu(n_channels, kernel_size=3, strides=2, name=name)
def up(n_channels: int):
    """Return a callable that upsamples ``x`` by 2 and merges in ``skip``.

    The returned function applies, in order: a 4x4 stride-2 transposed
    convolution ("same" padding, no bias, weights ~ N(0, 0.02)), batch
    normalization, element-wise addition of ``skip``, and a ReLU. New layers
    are constructed each time the returned function is called, so separate
    calls do not share weights.

    Args:
        n_channels: Output channels of the transposed convolution; must match
            the channel count of ``skip`` for the addition to be valid.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    def merge(x, skip):
        upsampled = Conv2DTranspose(
            n_channels, kernel_size=4, strides=2, padding="same",
            kernel_initializer=weight_init, use_bias=False,
        )(x)
        normalized = BatchNormalization()(upsampled)
        summed = Add()([normalized, skip])
        return ReLU()(summed)

    return merge
def last():
    """Return the output head: a 3-channel projection plus skip, then sigmoid.

    The returned function applies, in order: a 3x3 stride-1 transposed
    convolution to 3 channels ("same" padding, no bias, weights
    ~ N(0, 0.02)), batch normalization, element-wise addition of ``skip``,
    and a sigmoid activation. The output values therefore lie in (0, 1).
    ``skip`` must have 3 channels and the same spatial size as ``x``.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)

    def head(x, skip):
        projected = Conv2DTranspose(
            3, kernel_size=3, strides=1, padding="same",
            kernel_initializer=weight_init, use_bias=False,
        )(x)
        normalized = BatchNormalization()(projected)
        summed = Add()([normalized, skip])
        return Activation("sigmoid")(summed)

    return head
def deblurnet():
    """Build the DeBlurNet encoder-decoder as a Keras functional ``Model``.

    Input: a tensor with 15 channels and unspecified height/width. Channels
    6:9 (presumably the middle RGB frame of five stacked frames -- confirm
    against the data pipeline) are added back just before the final sigmoid,
    so the network effectively predicts a residual on that frame.

    Architecture: F0 (5x5 conv), then three stride-2 downsampling stages
    (D1/F1_*, D2/F2_*, D3/F3_*), then three 2x upsampling stages (U1/F4_*,
    U2/F5_*, U3/F6_1). Each upsampling stage adds the encoder feature map of
    matching resolution (F2_3, F1_2, F0). Because those additions need
    exactly matching shapes, height and width should be multiples of 8;
    otherwise the ``Add`` layers will see mismatched shapes.

    Returns:
        A ``Model`` whose output has 3 channels with values in (0, 1).
    """
    def convs(x, specs):
        # Apply conv_bn_relu blocks sequentially; each spec is (channels, name).
        for channels, layer_name in specs:
            x = conv_bn_relu(channels, name=layer_name)(x)
        return x

    inputs = Input(shape=[None, None, 15])
    center = inputs[:, :, :, 6:9]

    # Encoder: keep the feature map at each resolution for the decoder skips.
    f0 = conv_bn_relu(64, kernel_size=5, name="F0")(inputs)
    h = down(64, name="D1")(f0)
    f1 = convs(h, [(128, "F1_1"), (128, "F1_2")])
    h = down(256, name="D2")(f1)
    f2 = convs(h, [(256, "F2_1"), (256, "F2_2"), (256, "F2_3")])
    h = down(512, name="D3")(f2)
    h = convs(h, [(512, "F3_1"), (512, "F3_2"), (512, "F3_3")])

    # Decoder: upsample and merge with the encoder map of the same resolution.
    h = up(256)(h, f2)  # U1
    h = convs(h, [(256, "F4_1"), (256, "F4_2"), (256, "F4_3")])
    h = up(128)(h, f1)  # U2
    h = convs(h, [(128, "F5_1"), (64, "F5_2")])
    h = up(64)(h, f0)  # U3
    h = conv_bn_relu(15, name="F6_1")(h)
    outputs = last()(h, center)  # F6_2
    return Model(inputs=inputs, outputs=outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment