@innat
Last active October 14, 2023 19:03
TF.Keras Implementation of Convolutional Block Attention Module (CBAM)
import tensorflow as tf


class SpatialAttentionModule(tf.keras.layers.Layer):
    def __init__(self, kernel_size=3):
        '''
        paper: https://arxiv.org/abs/1807.06521
        code:  https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super(SpatialAttentionModule, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv2 = tf.keras.layers.Conv2D(32, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv3 = tf.keras.layers.Conv2D(16, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv4 = tf.keras.layers.Conv2D(1, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.math.sigmoid)

    def call(self, inputs):
        # pool across the channel axis, then stack into a 2-channel descriptor
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        x = tf.stack([avg_out, max_out], axis=3)
        # progressively reduce to a single-channel spatial attention map
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv4(x)
class ChannelAttentionModule(tf.keras.layers.Layer):
    def __init__(self, ratio=8):
        '''
        paper: https://arxiv.org/abs/1807.06521
        code:  https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super(ChannelAttentionModule, self).__init__()
        self.ratio = ratio
        self.gapavg = tf.keras.layers.GlobalAveragePooling2D()
        self.gmpmax = tf.keras.layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        # shared two-layer bottleneck: C -> C/ratio -> C
        self.conv1 = tf.keras.layers.Conv2D(input_shape[-1] // self.ratio,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)
        # no activation here; the sigmoid is applied to the summed logits below
        self.conv2 = tf.keras.layers.Conv2D(input_shape[-1],
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=None)
        super(ChannelAttentionModule, self).build(input_shape)

    def call(self, inputs):
        # global average and max pooling over the spatial dimensions
        gapavg = self.gapavg(inputs)
        gmpmax = self.gmpmax(inputs)
        # restore (batch, 1, 1, C) so the 1x1 convolutions can be applied
        gapavg = tf.reshape(gapavg, (-1, 1, 1, gapavg.shape[-1]))
        gmpmax = tf.reshape(gmpmax, (-1, 1, 1, gmpmax.shape[-1]))
        # forward both descriptors through the shared bottleneck
        gapavg_out = self.conv2(self.conv1(gapavg))
        gmpmax_out = self.conv2(self.conv1(gmpmax))
        # sigmoid over the sum gives the channel attention weights
        return tf.math.sigmoid(gapavg_out + gmpmax_out)

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        # attention weights broadcast over the spatial dimensions
        return (input_shape[0], 1, 1, input_shape[-1])
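
The two modules are applied sequentially (channel attention first, then spatial attention), with each attention map multiplied back onto the features. Below is a minimal usage sketch on a dummy feature map; the tensor shape and the multiplication wrapper are illustrative assumptions, not part of the gist itself.

# hypothetical (batch, H, W, C) feature map for illustration
features = tf.random.normal((4, 32, 32, 64))

channel_att = ChannelAttentionModule(ratio=8)(features)        # -> (4, 1, 1, 64)
features = features * channel_att                              # channel-refined features
spatial_att = SpatialAttentionModule(kernel_size=3)(features)  # -> (4, 32, 32, 1)
features = features * spatial_att                              # CBAM output, (4, 32, 32, 64)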
innat commented Aug 7, 2021

Note

The CBAM layers above are integrated with the EfficientNet model in order to add an attention mechanism. Check out the following notebook.

Multi-Attention EfficientNet
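
A minimal sketch of how such an integration could look; the EfficientNetB0 backbone, input size, and 10-class head below are illustrative assumptions, not taken from the linked notebook.

# backbone, input size and classification head are assumptions for illustration
backbone = tf.keras.applications.EfficientNetB0(include_top=False,
                                                input_shape=(224, 224, 3))
x = backbone.output                              # last feature map of the backbone
x = x * ChannelAttentionModule()(x)              # channel attention
x = x * SpatialAttentionModule()(x)              # spatial attention
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)  # hypothetical 10-class head
model = tf.keras.Model(backbone.input, outputs)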
