TF.Keras Implementation of Convolutional Block Attention Module (CBAM)
import tensorflow as tf


class SpatialAttentionModule(tf.keras.layers.Layer):
    def __init__(self, kernel_size=3):
        '''
        paper: https://arxiv.org/abs/1807.06521
        code:  https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super(SpatialAttentionModule, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(64, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv2 = tf.keras.layers.Conv2D(32, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv3 = tf.keras.layers.Conv2D(16, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.nn.relu)
        self.conv4 = tf.keras.layers.Conv2D(1, kernel_size=kernel_size,
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same',
                                            activation=tf.math.sigmoid)

    def call(self, inputs):
        # pool across the channel axis: per-pixel average and max descriptors
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        # stack into a 2-channel map, then squeeze it down to a
        # 1-channel spatial attention map in [0, 1]
        x = tf.stack([avg_out, max_out], axis=3)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv4(x)
class ChannelAttentionModule(tf.keras.layers.Layer):
    def __init__(self, ratio=8):
        '''
        paper: https://arxiv.org/abs/1807.06521
        code:  https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super(ChannelAttentionModule, self).__init__()
        self.ratio = ratio
        self.gapavg = tf.keras.layers.GlobalAveragePooling2D()
        self.gmpmax = tf.keras.layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        # shared bottleneck MLP, implemented with 1x1 convolutions
        self.conv1 = tf.keras.layers.Conv2D(input_shape[-1] // self.ratio,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)
        self.conv2 = tf.keras.layers.Conv2D(input_shape[-1],
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)
        super(ChannelAttentionModule, self).build(input_shape)

    def call(self, inputs):
        # compute global average (GAP) and global max (GMP) pooling descriptors
        gapavg = self.gapavg(inputs)
        gmpmax = self.gmpmax(inputs)
        # reshape to (batch, 1, 1, channels) so the 1x1 convolutions can consume them
        gapavg = tf.keras.layers.Reshape((1, 1, gapavg.shape[1]))(gapavg)
        gmpmax = tf.keras.layers.Reshape((1, 1, gmpmax.shape[1]))(gmpmax)
        # forward both descriptors through the shared MLP and merge
        gapavg_out = self.conv2(self.conv1(gapavg))
        gmpmax_out = self.conv2(self.conv1(gmpmax))
        return tf.math.sigmoid(gapavg_out + gmpmax_out)
    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        # the attention map broadcasts against the input: (batch, 1, 1, channels)
        return (input_shape[0], 1, 1, input_shape[3])
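Usage sketch (not part of the gist itself): CBAM refines a feature map by applying channel attention first and spatial attention second, each time rescaling the features by broadcasting. The batch size and feature-map shape below are illustrative assumptions.

# illustrative (batch, H, W, C) feature map
feature_map = tf.random.normal([4, 32, 32, 128])

channel_attn = ChannelAttentionModule(ratio=8)        # produces a (batch, 1, 1, C) map
spatial_attn = SpatialAttentionModule(kernel_size=3)  # produces a (batch, H, W, 1) map

# channel attention first, then spatial attention, as in the CBAM paper
refined = feature_map * channel_attn(feature_map)
refined = refined * spatial_attn(refined)
print(refined.shape)                                  # (4, 32, 32, 128)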
Note
The CBAM layers above are meant to be integrated with an EfficientNet model to add an attention mechanism. Check out the following notebook.
Multi-Attention EfficientNet
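A rough integration sketch along those lines, assuming a tf.keras.applications.EfficientNetB0 backbone, a 224x224x3 input, and an illustrative 10-class head (the linked notebook may wire things up differently):

backbone = tf.keras.applications.EfficientNetB0(include_top=False,
                                                weights='imagenet',
                                                input_shape=(224, 224, 3))
features = backbone.output                            # top-level feature map of the backbone

# refine the backbone features with CBAM before the classification head
x = features * ChannelAttentionModule(ratio=8)(features)
x = x * SpatialAttentionModule(kernel_size=3)(x)

x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10, activation='softmax')(x)   # hypothetical 10-class head
model = tf.keras.Model(backbone.input, outputs)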