Cannot understand Model Summary
I'd like to truly understand the summary of my deep model's architecture (the output of model.summary()). The goal is to make sure the model processes everything exactly as I intend. To do that, I need to map each row of the summary back to the line of code that produced it.
Here is the code:
import tensorflow as tf
from tensorflow.keras.layers import Permute

def conv1d(x, channels, ks=1, strides=1, padding='same'):
    # Conv1D with ReLU and no bias; with the default ks=1 this is a pointwise convolution
    conv = tf.keras.layers.Conv1D(channels, ks, strides, padding, activation='relu', use_bias=False,
                                  kernel_initializer='HeNormal')(x)
    return conv

def my_self_attention(x, channels):
    size = x.shape
    # drop the singleton axis added by tf.expand_dims in the caller
    x = tf.reshape(x, shape=[-1, x.shape[2], x.shape[3]])
    f = conv1d(x, channels)
    g = conv1d(x, channels)
    h = conv1d(x, channels)
    attention_weights = tf.keras.activations.softmax(
        tf.matmul(g, Permute((2, 1))(f)))  # query multiply with key and then softmax on it
    sensor_att_fm = tf.matmul(attention_weights, h)
    gamma = tf.compat.v1.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))
    o = gamma * sensor_att_fm + x
    return tf.reshape(o, shape = [-1, 1, x.shape[1], x.shape[2]])

# my_input and channels come from earlier code that is not shown here
refined_fm = tf.concat([my_self_attention(tf.expand_dims(my_input[:, t, :, :], 1), channels) for t in range(my_input.shape[1])], 1)
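To relate each line of the code to a row of the summary, it can help to work out by hand the shapes that one pass through my_self_attention should produce. The sketch below does only that bookkeeping in plain Python, without TensorFlow. It assumes my_input has shape (None, 8, 6, 64) and channels = 64; both values are inferred from the summary rather than stated in the code, and the helper names are hypothetical.

```python
# Shape bookkeeping only; no TensorFlow needed. None stands for the batch axis.
# Assumed (inferred from the summary): my_input is (None, 8, 6, 64), channels = 64.

def conv1d_params(kernel_size, in_channels, filters, use_bias=False):
    """Number of trainable weights in a Conv1D layer."""
    return kernel_size * in_channels * filters + (filters if use_bias else 0)

def permute_last_two(shape):
    """Shape after Permute((2, 1)): swap the two non-batch axes."""
    return (shape[0], shape[2], shape[1])

def matmul_shape(a, b):
    """Shape of a batched matmul of (batch, m, k) with (batch, k, n)."""
    assert a[2] == b[1], "inner dimensions must match"
    return (a[0], a[1], b[2])

my_input = (None, 8, 6, 64)
channels = 64

sliced = (my_input[0], my_input[2], my_input[3])   # my_input[:, t, :, :]
expanded = (sliced[0], 1, sliced[1], sliced[2])    # tf.expand_dims(..., 1)
x = (None, expanded[2], expanded[3])               # tf.reshape(x, [-1, x.shape[2], x.shape[3]])
f = g = h = (x[0], x[1], channels)                 # Conv1D with ks=1, padding='same'
scores = matmul_shape(g, permute_last_two(f))      # tf.matmul(g, Permute((2, 1))(f))
attended = matmul_shape(scores, h)                 # tf.matmul(attention_weights, h)
out = (None, 1, x[1], x[2])                        # final tf.reshape
params_per_conv = conv1d_params(1, x[2], channels) # weights in each Conv1D

print(sliced, expanded, x, f, scores, attended, out, params_per_conv)
```

Under these assumptions each value should line up with one kind of row in the summary: the getitem rows with sliced, the tf.expand_dims rows with expanded, the tf.reshape rows with x, the Conv1D rows with f, g, h and params_per_conv, the Permute rows with permute_last_two(f), and tf.linalg.matmul with scores.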
Model Summary
multiply (Multiply) (None, 8, 6, 64) 0 dropout[0][0]
dropout_2[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_1 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_2 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_3 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_4 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_5 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_6 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.__operators__.getitem_7 (Sli (None, 6, 64) 0 multiply[0][0]
__________________________________________________________________________________________________
tf.expand_dims (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem[0][0]
__________________________________________________________________________________________________
tf.expand_dims_1 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_1[0][0]
__________________________________________________________________________________________________
tf.expand_dims_2 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_2[0][0]
__________________________________________________________________________________________________
tf.expand_dims_3 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_3[0][0]
__________________________________________________________________________________________________
tf.expand_dims_4 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_4[0][0]
__________________________________________________________________________________________________
tf.expand_dims_5 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_5[0][0]
__________________________________________________________________________________________________
tf.expand_dims_6 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_6[0][0]
__________________________________________________________________________________________________
tf.expand_dims_7 (TFOpLambda) (None, 1, 6, 64) 0 tf.__operators__.getitem_7[0][0]
__________________________________________________________________________________________________
tf.reshape (TFOpLambda) (None, 6, 64) 0 tf.expand_dims[0][0]
__________________________________________________________________________________________________
tf.reshape_2 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_1[0][0]
__________________________________________________________________________________________________
tf.reshape_4 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_2[0][0]
__________________________________________________________________________________________________
tf.reshape_6 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_3[0][0]
__________________________________________________________________________________________________
tf.reshape_8 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_4[0][0]
__________________________________________________________________________________________________
tf.reshape_10 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_5[0][0]
__________________________________________________________________________________________________
tf.reshape_12 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_6[0][0]
__________________________________________________________________________________________________
tf.reshape_14 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_7[0][0]
__________________________________________________________________________________________________
conv1d (Conv1D) (None, 6, 64) 4096 tf.reshape[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 6, 64) 4096 tf.reshape_2[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D) (None, 6, 64) 4096 tf.reshape_4[0][0]
__________________________________________________________________________________________________
conv1d_9 (Conv1D) (None, 6, 64) 4096 tf.reshape_6[0][0]
__________________________________________________________________________________________________
conv1d_12 (Conv1D) (None, 6, 64) 4096 tf.reshape_8[0][0]
__________________________________________________________________________________________________
conv1d_15 (Conv1D) (None, 6, 64) 4096 tf.reshape_10[0][0]
__________________________________________________________________________________________________
conv1d_18 (Conv1D) (None, 6, 64) 4096 tf.reshape_12[0][0]
__________________________________________________________________________________________________
conv1d_21 (Conv1D) (None, 6, 64) 4096 tf.reshape_14[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 6, 64) 4096 tf.reshape[0][0]
__________________________________________________________________________________________________
permute (Permute) (None, 64, 6) 0 conv1d[0][0]
__________________________________________________________________________________________________
conv1d_4 (Conv1D) (None, 6, 64) 4096 tf.reshape_2[0][0]
__________________________________________________________________________________________________
permute_1 (Permute) (None, 64, 6) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, 6, 64) 4096 tf.reshape_4[0][0]
__________________________________________________________________________________________________
permute_2 (Permute) (None, 64, 6) 0 conv1d_6[0][0]
__________________________________________________________________________________________________
conv1d_10 (Conv1D) (None, 6, 64) 4096 tf.reshape_6[0][0]
__________________________________________________________________________________________________
permute_3 (Permute) (None, 64, 6) 0 conv1d_9[0][0]
__________________________________________________________________________________________________
conv1d_13 (Conv1D) (None, 6, 64) 4096 tf.reshape_8[0][0]
__________________________________________________________________________________________________
permute_4 (Permute) (None, 64, 6) 0 conv1d_12[0][0]
__________________________________________________________________________________________________
conv1d_16 (Conv1D) (None, 6, 64) 4096 tf.reshape_10[0][0]
__________________________________________________________________________________________________
permute_5 (Permute) (None, 64, 6) 0 conv1d_15[0][0]
__________________________________________________________________________________________________
conv1d_19 (Conv1D) (None, 6, 64) 4096 tf.reshape_12[0][0]
__________________________________________________________________________________________________
permute_6 (Permute) (None, 64, 6) 0 conv1d_18[0][0]
__________________________________________________________________________________________________
conv1d_22 (Conv1D) (None, 6, 64) 4096 tf.reshape_14[0][0]
__________________________________________________________________________________________________
permute_7 (Permute) (None, 64, 6) 0 conv1d_21[0][0]
__________________________________________________________________________________________________
tf.linalg.matmul (TFOpLambda) (None, 6, 6) 0 conv1d_1[0][0]
permute[0][0]
Specifically, I am confused about why a convolution on a reshaped tensor appears again after a Permute layer, as in these two consecutive rows:
permute_1 (Permute) (None, 64, 6) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D) (None, 6, 64) 4096 tf.reshape_4[0][0]
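One way to check whether two adjacent rows are actually connected is to read the last column of the summary ("Connected to") rather than the row order. The rows in the summary appear to be grouped by depth in the graph (all getitem rows, then all expand_dims rows, and so on), so rows from different loop iterations end up interleaved. The sketch below is a hypothetical helper, not part of Keras: it parses a few rows copied from the summary and follows the "Connected to" entries upstream.

```python
import re

# A few rows copied from the summary above (separator lines omitted).
rows = """\
tf.reshape_2 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_1[0][0]
tf.reshape_4 (TFOpLambda) (None, 6, 64) 0 tf.expand_dims_2[0][0]
conv1d_3 (Conv1D) (None, 6, 64) 4096 tf.reshape_2[0][0]
permute_1 (Permute) (None, 64, 6) 0 conv1d_3[0][0]
conv1d_7 (Conv1D) (None, 6, 64) 4096 tf.reshape_4[0][0]
"""

def parse_connections(text):
    """Map each layer name to the layer names listed in its 'Connected to' column."""
    inputs_of = {}
    for line in text.splitlines():
        if not line.strip():
            continue
        name = line.split()[0]
        # Inbound layers are printed as name[node_index][tensor_index].
        inputs_of[name] = re.findall(r"(\S+)\[\d+\]\[\d+\]", line)
    return inputs_of

def upstream_chain(inputs_of, name):
    """Follow the first inbound layer repeatedly, as far as the parsed rows allow."""
    chain = [name]
    while inputs_of.get(chain[-1]):
        chain.append(inputs_of[chain[-1]][0])
    return chain

connections = parse_connections(rows)
print(upstream_chain(connections, "permute_1"))
print(upstream_chain(connections, "conv1d_7"))
```

For these rows, permute_1 traces back through conv1d_3 and tf.reshape_2 to tf.expand_dims_1, while conv1d_7 traces back through tf.reshape_4 to tf.expand_dims_2. In other words, the two rows seem to belong to different iterations of the loop over t, and conv1d_7 takes its input from tf.reshape_4, not from permute_1.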
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
