# Source: GitHub Gist innat/437a5b1a72dad9f38b405102efe0a697
# Author: @innat, created November 16, 2021 18:04
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras import layers
class ConvoBlocks(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU building block.

    Args:
        num_filters: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        dilation_rate: Dilation rate of the convolution.
        padding: Conv2D padding mode; with the default stride of 1, "same"
            keeps the spatial size of the input.
        use_bias: Whether the convolution adds a bias. Defaults to False
            because the following BatchNormalization adds its own offset.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_filters=256, kernel_size=3, dilation_rate=1,
                 padding="same", use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.padding = padding
        self.use_bias = use_bias
        self.conv = layers.Conv2D(
            filters=self.num_filters,
            kernel_size=self.kernel_size,
            dilation_rate=self.dilation_rate,
            padding=self.padding,
            use_bias=self.use_bias,
        )
        self.bn = layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Apply conv, batch-norm and ReLU to ``inputs``.

        Args:
            inputs: Feature map; assumed channels-last (Conv2D default).
            training: Forwarded to BatchNormalization so it uses batch
                statistics when True and moving averages when False.

        Returns:
            The activated feature map with ``num_filters`` channels.
        """
        x = self.conv(inputs)
        # Pass `training` explicitly instead of relying on implicit
        # call-context propagation.
        x = self.bn(x, training=training)
        return tf.nn.relu(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "num_filters": self.num_filters,
            "kernel_size": self.kernel_size,
            "dilation_rate": self.dilation_rate,
            "padding": self.padding,
            "use_bias": self.use_bias,
        })
        return config
if __name__ == "__main__":
    # Smoke test: the default ConvoBlocks (256 filters, "same" padding, stride 1)
    # keeps the 100x100 grid and emits 256 channels -> (1, 100, 100, 256).
    conv_block = ConvoBlocks()
    print(conv_block(tf.ones((1, 100, 100, 3))).shape)
class DilatedSpatialPyramidPooling(tf.keras.layers.Layer):
    """Atrous spatial pyramid pooling head.

    Runs five parallel branches over the input feature map:
      * image-level: global average pool -> 1x1 ConvoBlocks (with bias),
        broadcast back over the input's spatial grid;
      * a 1x1 ConvoBlocks;
      * three 3x3 ConvoBlocks with dilation rates 6, 12 and 18.
    The branch outputs are concatenated on the channel axis and fused by a
    final 1x1 ConvoBlocks. Every block uses the ConvoBlocks default of 256
    filters and "same" padding, so the output keeps the input's spatial size
    and has 256 channels.

    Args:
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv_block_a = ConvoBlocks(kernel_size=1, use_bias=True)
        self.conv_block_b = ConvoBlocks(kernel_size=1, dilation_rate=1)
        self.conv_block_c = ConvoBlocks(kernel_size=3, dilation_rate=6)
        self.conv_block_d = ConvoBlocks(kernel_size=3, dilation_rate=12)
        self.conv_block_e = ConvoBlocks(kernel_size=3, dilation_rate=18)
        self.conv_block_f = ConvoBlocks(kernel_size=1)

    def call(self, inputs, training=None):
        """Apply the ASPP head to ``inputs``.

        Args:
            inputs: Channels-last feature map (batch, height, width, channels)
                — assumed from the spatial axes used here; height/width may be
                unknown (None) at trace time.
            training: Forwarded to every ConvoBlocks (BatchNorm mode).

        Returns:
            Tensor with the same batch/spatial size as ``inputs`` and 256
            channels.
        """
        # Image-level branch. Averaging over both spatial axes is what a
        # full-size AveragePooling2D computes, but it needs no static sizes
        # and creates no new layer objects per call.
        pooled = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
        pooled = self.conv_block_a(pooled, training=training)
        # Bilinear upsampling of a 1x1 map is a constant fill, so broadcasting
        # against an (h, w) grid of ones gives the same result and keeps the
        # static shape whenever it is known.
        grid = tf.ones_like(inputs[..., :1], dtype=pooled.dtype)
        image_level = pooled * grid

        branches = [image_level] + [
            block(inputs, training=training)
            for block in (self.conv_block_b, self.conv_block_c,
                          self.conv_block_d, self.conv_block_e)
        ]
        merged = tf.concat(branches, axis=-1)
        return self.conv_block_f(merged, training=training)
# Default square input resolution (pixels) referenced by DeeplabV3Plus.
image_size = 512

if __name__ == "__main__":
    # Smoke test: the ASPP head keeps the 100x100 grid and emits 256 channels
    # -> (1, 100, 100, 256).
    aspp = DilatedSpatialPyramidPooling()
    print(aspp(tf.ones((1, 100, 100, 3))).shape)
class DeeplabV3Plus(tf.keras.Model):
    """DeepLabV3+-style segmentation model on an ImageNet-pretrained ResNet50.

    The backbone exposes two feature maps. The deep one
    (``conv4_block6_2_relu``) goes through the ASPP head. The shallower,
    higher-resolution one (``conv2_block3_2_relu``) is the decoder's skip
    connection.

    The decoder works as follows:
      1. The ASPP output is bilinearly resized to the skip map's spatial size.
      2. It is concatenated with a 48-channel 1x1 projection of the skip map.
      3. The result is refined by two 3x3 ConvoBlocks.
      4. It is resized to the input's spatial size.
      5. A final 1x1 convolution with no activation projects it to
         ``num_classes`` channels, so the outputs are logits.

    Resize targets come from the actual tensors, not from the module-level
    ``image_size``. The model therefore works for input resolutions other
    than ``image_size`` x ``image_size``, e.g. 224x224.

    Args:
        num_classes: Number of output channels of the final convolution.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        # Leave the spatial dims unspecified so the backbone is not pinned to
        # one resolution (allowed because include_top=False).
        base = keras.applications.ResNet50(
            weights="imagenet",
            include_top=False,
            input_shape=(None, None, 3),
        )
        # Multi-output view of the backbone: (deep features, skip features).
        self.new_base = Model(
            base.input,
            [
                base.get_layer("conv4_block6_2_relu").output,
                base.get_layer("conv2_block3_2_relu").output,
            ],
        )
        self.dsp = DilatedSpatialPyramidPooling()
        # 1x1 projection of the skip features down to 48 channels.
        self.conv_block_a = ConvoBlocks(num_filters=48, kernel_size=1)
        self.conv_block_b = ConvoBlocks()
        self.conv_block_c = ConvoBlocks()
        self.last = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")

    @staticmethod
    def _resize_to(x, reference):
        """Bilinearly resize ``x`` to the spatial size (axes 1, 2) of ``reference``.

        Uses the static size when it is fully known, so output shapes stay
        defined during symbolic tracing. Otherwise it falls back to the
        dynamic ``tf.shape``. The result is cast back to ``x.dtype`` because
        bilinear ``tf.image.resize`` returns float32.
        """
        target = reference.shape[1:3]
        if target.is_fully_defined():
            size = target.as_list()
        else:
            size = tf.shape(reference)[1:3]
        return tf.cast(tf.image.resize(x, size, method="bilinear"), x.dtype)

    def call(self, inputs, training=None):
        """Run the encoder-decoder on a batch of RGB images.

        Args:
            inputs: Image batch with 3 trailing channels (the backbone's
                ``input_shape``).
            training: Forwarded to the backbone and every ConvoBlocks
                (BatchNorm mode).

        Returns:
            Per-pixel logits with the spatial size of ``inputs`` and
            ``num_classes`` channels.
        """
        deep, skip = self.new_base(inputs, training=training)
        x = self.dsp(deep, training=training)
        # Match the skip map exactly rather than using an integer factor
        # derived from a fixed image size.
        x = self._resize_to(x, skip)
        skip = self.conv_block_a(skip, training=training)
        x = tf.concat([x, skip], axis=-1)
        x = self.conv_block_b(x, training=training)
        x = self.conv_block_c(x, training=training)
        x = self._resize_to(x, inputs)
        return self.last(x)

    def build_graph(self, size=None):
        """Wrap the model in a functional ``Model`` (e.g. for ``plot_model``).

        Args:
            size: Square input resolution of the traced graph. Defaults to the
                module-level ``image_size``.

        Returns:
            A ``Model`` mapping a (size, size, 3) input to this model's output.
        """
        size = image_size if size is None else size
        x = tf.keras.Input(shape=(size, size, 3))
        return Model(inputs=[x], outputs=self.call(x))
if __name__ == "__main__":
    model = DeeplabV3Plus(num_classes=10)
    # Smoke test at the configured resolution: expect
    # (1, image_size, image_size, num_classes).
    print(model(tf.ones((1, image_size, image_size, 3))).shape)
    # NOTE: plot_model needs the optional pydot + graphviz packages installed
    # (see the Keras plot_model docs).
    tf.keras.utils.plot_model(model.build_graph(), expand_nested=True)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment