Understand and Implement Element-Wise Attention Module
Please leave a short comment with your thoughts so that I can improve my question. Thank you. :-)
I'm trying to understand and implement a research paper on Triple Attention Learning, which consists of
- channel-wise attention (a)
- element-wise attention (b)
- scale-wise attention (c)
The mechanism is integrated experimentally into the DenseNet model. The architecture diagram of the whole model is here. The channel-wise attention module is nothing but the squeeze-and-excitation block, which passes a sigmoid output on to the element-wise attention module. Below is a more precise feature-flow diagram of these modules (a, b, and c).
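For reference, a squeeze-and-excitation block is commonly written along the following lines. This is only a minimal sketch, not the paper's code: the class name `SqueezeExcite` and the `channels` and `ratio` parameters are illustrative assumptions.

```python
import tensorflow as tf


class SqueezeExcite(tf.keras.layers.Layer):
    """Channel-wise attention: rescale each channel by a sigmoid gate in (0, 1)."""

    def __init__(self, channels, ratio=16):
        super().__init__()
        # `channels` and `ratio` are illustrative parameters, not values from the paper.
        self.squeeze = tf.keras.layers.GlobalAveragePooling2D()
        self.reduce = tf.keras.layers.Dense(max(channels // ratio, 1), activation="relu")
        self.expand = tf.keras.layers.Dense(channels, activation="sigmoid")

    def call(self, x):
        # Squeeze: (B, H, W, C) -> (B, C)
        gate = self.squeeze(x)
        # Excite: bottleneck MLP ending in a sigmoid, still (B, C)
        gate = self.expand(self.reduce(gate))
        # Broadcast the gate over the spatial dimensions and rescale the input
        return x * gate[:, tf.newaxis, tf.newaxis, :]


x = tf.random.normal([2, 8, 8, 32])
x_ca = SqueezeExcite(channels=32, ratio=4)(x)  # same shape as x
```

The output keeps the input shape, so it can be fed directly to whatever plays the role of the element-wise module.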

Theory
For the most part, I was able to understand and implement it, but I got a bit lost in the element-wise attention section (part b of the above diagram). This is where I need your assistance. :-)
Here is a little theory on this topic to give you a rough idea of what all this is about. Please note that the paper is no longer openly accessible, but early after its release it was free on the publisher's page, and I saved it at that time. To be fair to everyone, I'm sharing it with you: Link. Anyway, Section 4.3 of the paper shows:

So first of all, the f(att) function (the left-middle part, b, of the first diagram) consists of three convolution layers: 512 kernels of 1 x 1, 512 kernels of 3 x 3, and C kernels of 1 x 1, followed by a Softmax activation. Here C is the number of classifiers, i.e. classes.
Next, f(att) is applied to the output of the channel-wise attention module, which, as mentioned, is simply an SENet module that gives a sigmoid probability score, i.e. X(CA). So from f(att) we get C softmax probability scores, and each of these scores is multiplied with the sigmoid output to finally produce the feature maps A (according to equation 4 in the above diagram).
Second, there are C linear classifiers, implemented as a 1 x 1 convolution layer with C kernels. This layer is also applied to the SENet module's output, X(CA), pixel-wise on each feature vector. In the end, it gives the feature maps S (equation 5 in the diagram below).
And third, they element-wise multiply each confidence score (of S) with the corresponding attention element of A. This multiplication is deliberate: it prevents unnecessary attention on the feature maps. To make it effective, they also minimize a weighted cross-entropy loss between the classification ground truth and the score vector.
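One possible reading of the three steps above, written as a Keras layer. This is a sketch under stated assumptions, not the paper's implementation: the softmax is assumed to run over spatial positions (one attention map per class), the intermediate ReLU activations are a guess, the score vector is assumed to be the spatial sum of A * S, and the names `ElementWiseAttentionSketch` and `hidden` are hypothetical.

```python
import tensorflow as tf


class ElementWiseAttentionSketch(tf.keras.layers.Layer):
    def __init__(self, num_classes, hidden=512):
        super().__init__()
        # f(att): 1x1 -> 3x3 -> 1x1 convolutions, the last one with C kernels.
        # ASSUMPTION: ReLU between the convolutions (not stated in the text above).
        self.f_att = tf.keras.Sequential([
            tf.keras.layers.Conv2D(hidden, 1, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(hidden, 3, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(num_classes, 1, padding="same"),
        ])
        # C linear classifiers applied pixel-wise, i.e. a 1x1 convolution with C kernels
        self.classifier = tf.keras.layers.Conv2D(num_classes, 1, padding="same")

    def call(self, x_ca):
        logits = self.f_att(x_ca)                              # (B, H, W, C)
        shape = tf.shape(logits)
        flat = tf.reshape(logits, [shape[0], -1, shape[3]])    # (B, H*W, C)
        # ASSUMPTION: softmax over spatial positions, separately for each class map
        att = tf.reshape(tf.nn.softmax(flat, axis=1), shape)   # attention maps A
        scores = self.classifier(x_ca)                         # confidence maps S
        weighted = att * scores                                # element-wise A * S
        # ASSUMPTION: the score vector is the spatial sum of the weighted maps
        score_vector = tf.reduce_sum(weighted, axis=[1, 2])    # (B, C)
        return weighted, score_vector


x_ca = tf.random.normal([2, 8, 8, 32])  # stands in for the SE block's output X(CA)
weighted, score_vector = ElementWiseAttentionSketch(num_classes=10, hidden=64)(x_ca)
```

Returning both tensors keeps the two roles separate: `weighted` can flow on to later layers, while `score_vector` is what a weighted cross-entropy loss would be computed on.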

My Query
Mostly, I don't properly understand the minimization strategy in the middle of the network. I'm looking for someone who can give me a proper understanding and implementation of the element-wise attention mechanism proposed in the paper (Section 4.3), in detail.
Implement
Here is minimal code to get started; it should be enough, I guess. This is a shallow implementation and still far from the original element-wise module. I'm not sure how to implement it properly. For now, I want it as a layer that can be plugged into any model. I have been trying it with MNIST and a simple conv net.
In summary, for MNIST we should have a network that contains both the channel-wise and element-wise attention modules, followed by a final 10-unit softmax layer. For example:
Net: Conv2D - Attentions-Module - GAP - Softmax(10)
The Attention-Module consists of two parts, channel-wise and element-wise. The element-wise part is supposed to have a Softmax too, and to minimize a weighted CE loss between the ground truth and the score vector coming from this module (according to the paper, as already described above). The module also passes weighted feature maps on to the subsequent layers. For more clarity, here is a simple schematic diagram of what we're looking for.
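The exact weighting scheme of the paper's loss is not given above. For the single-label MNIST setting, one plausible form is a class-weighted categorical cross-entropy. The sketch below assumes that form; the function name `weighted_categorical_ce` and the example weights are purely illustrative.

```python
import tensorflow as tf


def weighted_categorical_ce(y_true, probs, class_weights, eps=1e-7):
    """Mean over the batch of -sum_c w_c * y_c * log(p_c)."""
    # Clip to avoid log(0) when a predicted probability is exactly zero
    probs = tf.clip_by_value(probs, eps, 1.0)
    per_sample = -tf.reduce_sum(class_weights * y_true * tf.math.log(probs), axis=-1)
    return tf.reduce_mean(per_sample)


# Two samples, two classes; class 1 is weighted twice as heavily (illustrative values)
y_true = tf.constant([[1.0, 0.0], [0.0, 1.0]])
probs = tf.constant([[0.5, 0.5], [0.25, 0.75]])
class_weights = tf.constant([1.0, 2.0])
loss = weighted_categorical_ce(y_true, probs, class_weights)
```

A function with the `(y_true, y_pred)` signature can be passed to `model.compile` once the weights are bound, for example with `lambda y, p: weighted_categorical_ce(y, p, class_weights)`.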

Ok, for the channel-wise attention which should give us a single probability score (sigmoid), let's use a fake layer for now for simplicity:
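import tensorflow as tf

class FakeSE(tf.keras.layers.Layer):
    def __init__(self):
        super(FakeSE, self).__init__()
        # conv layer standing in for the real squeeze-and-excitation block
        self.conv = tf.keras.layers.Conv2D(10, padding='same',
                                           kernel_size=3)

    def call(self, input_tensor, training=False):
        x = self.conv(input_tensor)
        # sigmoid score in (0, 1), same spatial size as the input
        return tf.math.sigmoid(x)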
And for the element-wise attention part, the following is my failed attempt so far:
class ElementWiseAttention(tf.keras.layers.Layer):
    def __init__(self):
        # call the parent constructor first so Keras can track the sub-layers
        super(ElementWiseAttention, self).__init__()
        # for simplicity the f(att) function here has 2 convolutions instead of 3:
        # self.conv1 and self.conv2
        self.conv1 = tf.keras.layers.Conv2D(16,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.silu)
        self.conv2 = tf.keras.layers.Conv2D(10,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=False, activation=tf.keras.activations.softmax)
        # fake SENet or channel-wise attention module
        self.cam = FakeSE()
        # a linear layer
        self.linear = tf.keras.layers.Conv2D(10,
                                             kernel_size=1,
                                             strides=1, padding='same',
                                             use_bias=True, activation=None)

    def call(self, inputs):
        # 2 stacked conv layers (in the paper it's 3; we use 2 for simplicity)
        # this is the f(att)
        x = self.conv1(inputs)
        x = self.conv2(x)
        # this is the A = f(att)*X(CA)
        camx = self.cam(x)*x
        # this is S = X(CA)*Linear_Classifier
        linx = self.cam(self.linear(inputs))
        # element-wise multiply to prevent unnecessary attention
        # supposed to be minimized with a weighted cross-entropy loss
        out = tf.multiply(camx, linx)
        return out
The above is the layer of interest. If I understand the paper's wording correctly, this layer should not only minimize the weighted loss function between gt and score_vector but also produce weighted feature maps (2D).
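One way to get both behaviours in Keras is a two-output model: the score vector becomes an auxiliary output with its own loss, and the weighted feature maps continue to the main classifier. The following is a minimal sketch under several assumptions, not the paper's architecture: the attention uses a spatial softmax, the auxiliary loss is plain (unweighted) categorical cross-entropy, the loss weights are arbitrary, and `SpatialSum` is a hypothetical helper layer.

```python
import tensorflow as tf


class SpatialSum(tf.keras.layers.Layer):
    """Hypothetical helper: sum a (B, H, W, C) tensor over H and W."""

    def call(self, x):
        return tf.reduce_sum(x, axis=[1, 2])


num_classes, size = 10, 32
inputs = tf.keras.Input(shape=(size, size, 1))
feat = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)

# Attention maps A: one map per class. ASSUMPTION: softmax over spatial positions.
att = tf.keras.layers.Conv2D(num_classes, 1)(feat)
att = tf.keras.layers.Reshape((size * size, num_classes))(att)
att = tf.keras.layers.Softmax(axis=1)(att)
att = tf.keras.layers.Reshape((size, size, num_classes))(att)

# Confidence maps S from a pixel-wise linear classifier (1x1 convolution)
scores = tf.keras.layers.Conv2D(num_classes, 1)(feat)
weighted = tf.keras.layers.Multiply()([att, scores])  # element-wise A * S

# Auxiliary head: score vector from the attention module itself
aux = tf.keras.layers.Softmax(name="aux_scores")(SpatialSum()(weighted))

# Main head: weighted feature maps -> GAP -> Softmax(10)
pooled = tf.keras.layers.GlobalAveragePooling2D()(weighted)
main = tf.keras.layers.Dense(num_classes, activation="softmax", name="main")(pooled)

model = tf.keras.Model(inputs, [main, aux])
model.compile(optimizer="adam",
              loss=["categorical_crossentropy", "categorical_crossentropy"],
              loss_weights=[1.0, 0.5])  # arbitrary illustrative weights

# Tiny random batch just to show that both losses are trained together
x = tf.random.uniform([8, size, size, 1])
y = tf.one_hot(tf.random.uniform([8], maxval=num_classes, dtype=tf.int32), num_classes)
model.fit(x, [y, y], epochs=1, batch_size=4, verbose=0)
main_pred, aux_pred = model.predict(x, verbose=0)
```

The same labels are passed twice because both heads predict the class. A weighted loss could replace the second entry of `loss` without touching the rest of the graph.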
Run
Here is the toy data
import numpy as np

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1)
x_train = x_train.astype('float32') / 255
x_train = tf.image.resize(x_train, [32, 32])  # DenseNet121 needs at least 32x32 inputs
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)

# Model
input = tf.keras.Input(shape=(32, 32, 1))
efnet = tf.keras.applications.DenseNet121(weights=None,
                                          include_top=False,
                                          input_tensor=input)
em = ElementWiseAttention()(efnet.output)

# Apply global max pooling.
gap = tf.keras.layers.GlobalMaxPooling2D()(em)

# classification layer.
output = tf.keras.layers.Dense(10, activation='softmax')(gap)

# bind all
func_model = tf.keras.Model(efnet.input, output)
func_model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()],
    optimizer=tf.keras.optimizers.Adam())

# fit
func_model.fit(x_train, y_train, batch_size=32, epochs=3, verbose=1)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
