Skip to content

Instantly share code, notes, and snippets.

@innat
Created November 12, 2021 22:43
Show Gist options
  • Save innat/9c4a6a7a1cd152f0d67cb4fe5e59926c to your computer and use it in GitHub Desktop.
# Third-party dependencies: TensorFlow and its bundled Keras API.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

print('TensorFlow', tf.__version__)
class ResidualBlock(layers.Layer):
    """A basic (two 3x3 conv) ResNet residual block.

    block_type='identity' keeps the spatial size (stride 1, identity
    shortcut); block_type='conv' halves the spatial size (stride 2) and
    projects the shortcut with a 1x1 conv so the shapes match for the add.

    Args:
        block_type: 'identity' (default) or 'conv'.
        n_filters: number of filters used by every conv in the block.

    Raises:
        ValueError: if block_type is not one of the two supported values.
    """

    def __init__(self, block_type='identity', n_filters=None):
        super().__init__()
        self.n_filters = n_filters
        if block_type == 'identity':
            self.strides = 1
        elif block_type == 'conv':
            self.strides = 2
            # 1x1 projection so the shortcut matches the main path's
            # downsampled spatial size and filter count.
            self.conv_shortcut = layers.Conv2D(filters=self.n_filters,
                                               kernel_size=1,
                                               padding='same',
                                               strides=self.strides,
                                               kernel_initializer='he_normal')
            self.bn_shortcut = layers.BatchNormalization(momentum=0.9)
        else:
            # The original left self.strides unset for any other value,
            # which surfaced later as a confusing AttributeError in call().
            raise ValueError(
                "block_type must be 'identity' or 'conv', "
                f"got {block_type!r}")
        self.conv_1 = layers.Conv2D(filters=self.n_filters,
                                    kernel_size=3,
                                    padding='same',
                                    strides=self.strides,
                                    kernel_initializer='he_normal')
        self.conv_2 = layers.Conv2D(filters=self.n_filters,
                                    kernel_size=3,
                                    padding='same',
                                    kernel_initializer='he_normal')
        self.bn_1 = layers.BatchNormalization(momentum=0.9)
        self.bn_2 = layers.BatchNormalization(momentum=0.9)

    def call(self, inputs, training=False):
        # Forward `training` into every BatchNormalization call: the
        # original dropped it, so BN always ran in inference mode
        # (moving averages) even while training.
        shortcut = inputs
        if self.strides == 2:
            shortcut = self.conv_shortcut(inputs)
            shortcut = self.bn_shortcut(shortcut, training=training)
        x = self.conv_1(inputs)
        x = tf.nn.relu(self.bn_1(x, training=training))
        x = self.conv_2(x)
        x = self.bn_2(x, training=training)
        # Residual add, then the block's final non-linearity.
        return tf.nn.relu(tf.add(shortcut, x))
class ResNet34(tf.keras.Model):
    """ResNet-34 assembled from ResidualBlock units (3-4-6-3 stages).

    Args:
        include_top: if True, append global-average-pool + Dense head.
        n_classes: number of output logits for the classification head.
    """

    def __init__(self, include_top=True, n_classes=1000):
        super().__init__()
        self.n_classes = n_classes
        self.include_top = include_top
        # Stem: explicit 3-pixel zero pad followed by a 7x7/2 conv.
        # padding='valid' because the padding is already applied
        # explicitly — the original used 'same', which padded twice and
        # produced 115x115 features instead of the canonical 112x112.
        self.conv_1 = layers.Conv2D(filters=64,
                                    kernel_size=7,
                                    padding='valid',
                                    strides=2,
                                    kernel_initializer='he_normal')
        self.bn_1 = layers.BatchNormalization(momentum=0.9)
        self.zero_pad_1 = layers.ZeroPadding2D(padding=(3, 3))
        self.bn_2 = layers.BatchNormalization(momentum=0.9)
        self.zero_pad_2 = layers.ZeroPadding2D(padding=(1, 1))
        # Same double-padding issue as the stem conv: the 1-pixel
        # explicit pad replaces 'same' padding here.
        self.maxpool = layers.MaxPool2D(3, 2, padding='valid')
        # Four stages of residual blocks; the first block of every
        # downscaling stage is a 'conv' (projection) block.
        self.residual_blocks = keras.Sequential()
        for n_filters, reps, downscale in zip([64, 128, 256, 512],
                                              [3, 4, 6, 3],
                                              [False, True, True, True]):
            for i in range(reps):
                if i == 0 and downscale:
                    self.residual_blocks.add(
                        ResidualBlock(block_type='conv', n_filters=n_filters))
                else:
                    self.residual_blocks.add(
                        ResidualBlock(block_type='identity', n_filters=n_filters))
        self.gap = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(units=self.n_classes)

    def call(self, inputs, training=False):
        # Forward `training` to BN layers and the residual stack; the
        # original dropped it, so BN always used inference statistics.
        # NOTE(review): bn_1 on the raw inputs acts as learned input
        # normalization — unusual but kept as in the original design.
        x = self.bn_1(inputs, training=training)
        x = self.zero_pad_1(x)
        x = self.conv_1(x)
        x = tf.nn.relu(self.bn_2(x, training=training))
        x = self.zero_pad_2(x)
        x = self.maxpool(x)
        x = self.residual_blocks(x, training=training)
        if self.include_top:
            x = self.gap(x)
            x = self.fc(x)
        return x

    def build_graph(self):
        """Wrap the subclassed model in a functional Model so that
        .summary() can report per-layer output shapes."""
        x = tf.keras.Input(shape=(224, 224, 3))
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
# Smoke-test: build the network on an ImageNet-sized input and print a
# per-layer summary through the functional-API wrapper.
model = ResNet34()
model.build((1, 224, 224, 3))
summary_model = model.build_graph()
summary_model.summary(line_length=120, expand_nested=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment