Cannot understand Model Summary

I'd like to fully understand the output of model.summary() for my deep model. My goal is to make sure the network processes everything exactly as I intend, and to do that I need to map each row of the summary back to the line of code that produced it.

Here is the code:

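import tensorflow as tf
from tensorflow.keras.layers import Permute


def conv1d(x, channels, ks=1, strides=1, padding='same'):
    # 1D convolution with ReLU and no bias; every call creates a NEW Conv1D layer
    conv = tf.keras.layers.Conv1D(channels, ks, strides, padding, activation='relu', use_bias=False,
                                  kernel_initializer='HeNormal')(x)
    return conv


def my_self_attention(x, channels):
    # (batch, 1, sensors, features) -> (batch, sensors, features)
    x = tf.reshape(x, shape=[-1, x.shape[2], x.shape[3]])
    f = conv1d(x, channels)  # key
    g = conv1d(x, channels)  # query
    h = conv1d(x, channels)  # value
    # query @ key^T, then softmax over the last axis
    attention_weights = tf.keras.activations.softmax(
        tf.matmul(g, Permute((2, 1))(f)))
    sensor_att_fm = tf.matmul(attention_weights, h)
    gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
    o = gamma * sensor_att_fm + x  # scaled attention output plus residual connection
    # (batch, sensors, features) -> (batch, 1, sensors, features)
    return tf.reshape(o, shape=[-1, 1, x.shape[1], x.shape[2]])


# my_input is the (None, 8, 6, 64) output of the `multiply` layer in the summary below;
# the attention block is applied separately to each of the 8 time steps
refined_fm = tf.concat(
    [my_self_attention(tf.expand_dims(my_input[:, t, :, :], 1), channels)
     for t in range(my_input.shape[1])],
    1)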
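To relate each summary row to a line of code, the forward pass can be mirrored shape by shape in plain NumPy. This is a minimal sketch, not the Keras model itself: it assumes the input shape (None, 8, 6, 64) read from the summary, a batch size of 2, and random stand-in weights, and conv1d_k1, softmax and my_self_attention_np are hypothetical helpers. Because kernel_size=1 and use_bias=False, each Conv1D reduces to a per-position matrix multiply followed by ReLU.

```python
import numpy as np

# Shapes read off the summary: (batch, time_steps, sensors, channels) = (None, 8, 6, 64).
# batch=2 and the random weights are stand-ins for illustration only.
batch, steps, sensors, channels = 2, 8, 6, 64
rng = np.random.default_rng(0)
my_input = rng.standard_normal((batch, steps, sensors, channels))


def conv1d_k1(x, w):
    # Conv1D(kernel_size=1, use_bias=False, activation='relu') == per-position matmul + ReLU
    return np.maximum(x @ w, 0.0)


def softmax(z, axis=-1):
    # Numerically stable softmax over the last axis (the Keras default)
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def my_self_attention_np(x, wf, wg, wh, gamma=0.0):
    x = x.reshape(-1, x.shape[2], x.shape[3])        # tf.reshape       -> (batch, 6, 64)
    f = conv1d_k1(x, wf)                             # conv1d (key)     -> (batch, 6, 64)
    g = conv1d_k1(x, wg)                             # conv1d (query)   -> (batch, 6, 64)
    h = conv1d_k1(x, wh)                             # conv1d (value)   -> (batch, 6, 64)
    f_t = np.transpose(f, (0, 2, 1))                 # Permute((2, 1))  -> (batch, 64, 6)
    attn = softmax(g @ f_t)                          # matmul + softmax -> (batch, 6, 6)
    o = gamma * (attn @ h) + x                       # matmul, residual -> (batch, 6, 64)
    return o.reshape(-1, 1, x.shape[1], x.shape[2])  # tf.reshape       -> (batch, 1, 6, 64)


outputs = []
for t in range(steps):
    # Fresh weights per step, mirroring the new Conv1D layers created by each call
    wf, wg, wh = (0.1 * rng.standard_normal((channels, channels)) for _ in range(3))
    step = np.expand_dims(my_input[:, t, :, :], 1)   # getitem + expand_dims -> (batch, 1, 6, 64)
    outputs.append(my_self_attention_np(step, wf, wg, wh))

refined_fm = np.concatenate(outputs, axis=1)         # tf.concat(..., 1) -> (batch, 8, 6, 64)
print(refined_fm.shape)  # (2, 8, 6, 64)
```

With gamma=0.0, the initial value used in the question's code, the residual branch dominates and each block returns its input unchanged, so refined_fm equals my_input as long as gamma stays at 0.0.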

Model summary (excerpt):

 multiply (Multiply)             (None, 8, 6, 64)     0           dropout[0][0]                    
                                                                 dropout_2[0][0]                  
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_2 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_3 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_4 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_5 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_6 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.__operators__.getitem_7 (Sli (None, 6, 64)        0           multiply[0][0]                   
__________________________________________________________________________________________________
tf.expand_dims (TFOpLambda)     (None, 1, 6, 64)     0           tf.__operators__.getitem[0][0]   
__________________________________________________________________________________________________
tf.expand_dims_1 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_1[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_2 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_2[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_3 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_3[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_4 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_4[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_5 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_5[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_6 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_6[0][0] 
__________________________________________________________________________________________________
tf.expand_dims_7 (TFOpLambda)   (None, 1, 6, 64)     0           tf.__operators__.getitem_7[0][0] 
__________________________________________________________________________________________________
tf.reshape (TFOpLambda)         (None, 6, 64)        0           tf.expand_dims[0][0]             
__________________________________________________________________________________________________
tf.reshape_2 (TFOpLambda)       (None, 6, 64)        0           tf.expand_dims_1[0][0]           
__________________________________________________________________________________________________
tf.reshape_4 (TFOpLambda)       (None, 6, 64)        0           tf.expand_dims_2[0][0]           
__________________________________________________________________________________________________
tf.reshape_6 (TFOpLambda)       (None, 6, 64)        0           tf.expand_dims_3[0][0]           
__________________________________________________________________________________________________
tf.reshape_8 (TFOpLambda)       (None, 6, 64)        0           tf.expand_dims_4[0][0]           
__________________________________________________________________________________________________
tf.reshape_10 (TFOpLambda)      (None, 6, 64)        0           tf.expand_dims_5[0][0]           
__________________________________________________________________________________________________
tf.reshape_12 (TFOpLambda)      (None, 6, 64)        0           tf.expand_dims_6[0][0]           
__________________________________________________________________________________________________
tf.reshape_14 (TFOpLambda)      (None, 6, 64)        0           tf.expand_dims_7[0][0]           
__________________________________________________________________________________________________
conv1d (Conv1D)                 (None, 6, 64)        4096        tf.reshape[0][0]                 
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 6, 64)        4096        tf.reshape_2[0][0]               
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, 6, 64)        4096        tf.reshape_4[0][0]               
__________________________________________________________________________________________________
conv1d_9 (Conv1D)               (None, 6, 64)        4096        tf.reshape_6[0][0]               
__________________________________________________________________________________________________
conv1d_12 (Conv1D)              (None, 6, 64)        4096        tf.reshape_8[0][0]               
__________________________________________________________________________________________________
conv1d_15 (Conv1D)              (None, 6, 64)        4096        tf.reshape_10[0][0]              
__________________________________________________________________________________________________
conv1d_18 (Conv1D)              (None, 6, 64)        4096        tf.reshape_12[0][0]              
__________________________________________________________________________________________________
conv1d_21 (Conv1D)              (None, 6, 64)        4096        tf.reshape_14[0][0]              
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 6, 64)        4096        tf.reshape[0][0]                 
__________________________________________________________________________________________________
permute (Permute)               (None, 64, 6)        0           conv1d[0][0]                     
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, 6, 64)        4096        tf.reshape_2[0][0]               
__________________________________________________________________________________________________
permute_1 (Permute)             (None, 64, 6)        0           conv1d_3[0][0]                   
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 6, 64)        4096        tf.reshape_4[0][0]               
__________________________________________________________________________________________________
permute_2 (Permute)             (None, 64, 6)        0           conv1d_6[0][0]                   
__________________________________________________________________________________________________
conv1d_10 (Conv1D)              (None, 6, 64)        4096        tf.reshape_6[0][0]               
__________________________________________________________________________________________________
permute_3 (Permute)             (None, 64, 6)        0           conv1d_9[0][0]                   
__________________________________________________________________________________________________
conv1d_13 (Conv1D)              (None, 6, 64)        4096        tf.reshape_8[0][0]               
__________________________________________________________________________________________________
permute_4 (Permute)             (None, 64, 6)        0           conv1d_12[0][0]                  
__________________________________________________________________________________________________
conv1d_16 (Conv1D)              (None, 6, 64)        4096        tf.reshape_10[0][0]              
__________________________________________________________________________________________________
permute_5 (Permute)             (None, 64, 6)        0           conv1d_15[0][0]                  
__________________________________________________________________________________________________
conv1d_19 (Conv1D)              (None, 6, 64)        4096        tf.reshape_12[0][0]              
__________________________________________________________________________________________________
permute_6 (Permute)             (None, 64, 6)        0           conv1d_18[0][0]                  
__________________________________________________________________________________________________
conv1d_22 (Conv1D)              (None, 6, 64)        4096        tf.reshape_14[0][0]              
__________________________________________________________________________________________________
permute_7 (Permute)             (None, 64, 6)        0           conv1d_21[0][0]                  
__________________________________________________________________________________________________
tf.linalg.matmul (TFOpLambda)   (None, 6, 6)         0           conv1d_1[0][0]                   
                                                                 permute[0][0]                    
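The Param # column can be checked by hand. This is a minimal sketch assuming channels=64 (as the output shapes suggest) and the Conv1D kernel layout kernel_size × input_channels × filters; conv1d_params is a hypothetical helper. The slice, expand_dims, reshape, Permute and matmul rows show 0 because those operations have no trainable weights.

```python
def conv1d_params(kernel_size, in_channels, filters, use_bias):
    # Each filter has kernel_size * in_channels weights, plus one bias if enabled.
    weights = kernel_size * in_channels * filters
    return weights + (filters if use_bias else 0)


# Values from the question's code and summary: ks=1, 64 input channels,
# channels=64 filters, use_bias=False.
per_layer = conv1d_params(1, 64, 64, use_bias=False)
print(per_layer)                                  # 4096
print(conv1d_params(1, 64, 64, use_bias=True))    # 4160

# 8 time steps x 3 Conv1D layers (f, g, h) per my_self_attention call
print(8 * 3 * per_layer)                          # 98304
```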

Specifically, I am confused about why a convolution and a reshape seem to happen again after a Permute, for example:

permute_1 (Permute)             (None, 64, 6)        0           conv1d_3[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 6, 64)        4096        tf.reshape_4[0][0]
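One possible reading of this ordering, stated as an assumption rather than documented behaviour: in TF 2.x, the layer list of a functional model appears to be grouped by depth in the graph (the longest path to the output) rather than by the order in which the Python lines ran, and layers of the same type are numbered in creation order. Under those assumptions conv1d_7 would be the query convolution g of time step t=2, fed by that step's own tf.reshape_4, rather than a second convolution applied after permute_1. The sketch uses a hand-written, hypothetical graph of one my_self_attention call and hypothetical helpers name, layer_names and depth.

```python
from collections import defaultdict
from functools import lru_cache


def name(base, i):
    # Keras-style auto naming: first instance has no suffix, then _1, _2, ...
    return base if i == 0 else f"{base}_{i}"


def layer_names(t):
    # Hypothetical helper. Assumes each my_self_attention call creates two
    # tf.reshape ops (input and output), three Conv1D layers (f, g, h) and one
    # Permute, and that no layers of these types were created earlier.
    return {
        "reshape_in": name("tf.reshape", 2 * t),
        "f (key)": name("conv1d", 3 * t),
        "g (query)": name("conv1d", 3 * t + 1),
        "h (value)": name("conv1d", 3 * t + 2),
        "permute": name("permute", t),
    }


# Hand-written graph of ONE my_self_attention call: node -> nodes consuming it.
consumers = {
    "reshape_in": ["f (key)", "g (query)", "h (value)", "add (+ x)"],
    "f (key)": ["permute"],
    "permute": ["matmul g @ f^T"],
    "g (query)": ["matmul g @ f^T"],
    "matmul g @ f^T": ["softmax"],
    "softmax": ["matmul attn @ h"],
    "h (value)": ["matmul attn @ h"],
    "matmul attn @ h": ["multiply gamma *"],
    "multiply gamma *": ["add (+ x)"],
    "add (+ x)": ["reshape_out"],
    "reshape_out": [],
}


@lru_cache(maxsize=None)
def depth(node):
    # Depth = length of the longest path from this node to the block's output.
    return max((depth(c) + 1 for c in consumers[node]), default=0)


by_depth = defaultdict(list)
for node in consumers:
    by_depth[depth(node)].append(node)

print(layer_names(2))
for d in sorted(by_depth, reverse=True):
    print(d, by_depth[d])
```

In this model of the graph, each f convolution sits one level deeper than its Permute and the matching g convolution, so all eight f rows would be listed first, followed by interleaved g and Permute rows, which is the pattern visible in the summary above.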



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
