TensorFlow Keras BatchNormalization for higher than 4-dimension tensor (video input)

I'm trying to implement S3D (https://arxiv.org/pdf/1712.04851.pdf) for video classification, and I have run into a problem with BatchNormalization.

Since I'm working on video classification, my input tensor needs an additional temporal dimension, i.e. [Batch, Time, Height, Width, Channel].

Here is the code that reproduces the error:

import numpy as np
import tensorflow as tf

example = np.random.randint(0, 255, (16, 16, 56, 56, 3))
example_tensor = tf.convert_to_tensor(example, dtype=tf.float32)

print(tf.keras.layers.BatchNormalization(axis=0)(example_tensor))
#print(tf.keras.layers.BatchNormalization(axis=1)(example_tensor))  # This gives an error
print(tf.keras.layers.BatchNormalization(axis=2)(example_tensor))
print(tf.keras.layers.BatchNormalization(axis=3)(example_tensor))
#print(tf.keras.layers.BatchNormalization(axis=-1)(example_tensor))  # This gives an error

The error message is:

InvalidArgumentError: Exception encountered when calling layer "batch_normalization_56" (type BatchNormalization).

input must be 4-dimensional[16,16,56,56,3] [Op:FusedBatchNormV3]

Call arguments received:
  • inputs=tf.Tensor(shape=(16, 16, 56, 56, 3), dtype=float32)
  • training=None

I read about the meaning of the axis argument of BatchNormalization in this Stack Overflow question, but I still don't understand why my BatchNormalization call raises an error depending on which axis I pass.

I've also searched through many related questions and read the TensorFlow BatchNormalization documentation. [link]

I think this error message is telling me that the layer expects a 4-dimensional input, as is usual for image processing ([Batch, Height, Width, Channel]).

Does anyone know what is happening here, and how to use BatchNormalization on a 5-dimensional tensor?



Solution 1:[1]

I'm also experiencing the same issue, on an M1 chip with TF version 2.7.0. In my case the dataset has dimensions (2518, 32, 32, 32, 3). I suspect that when batch norm runs with a batch size of 32 after a 64-filter conv layer, so the layer receives an input of shape (32, 32, 32, 32, 64), that input is supposed to be reshaped into (32*64, 32, 32, 32), which would be a valid shape for the batch norm.
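Based on that reshape idea, here is a minimal sketch of a workaround (using the shapes from the question, not from my dataset): fold the temporal dimension into the batch dimension so the 4-D kernel can run, apply the normalization, then restore the original 5-D shape.

import numpy as np
import tensorflow as tf

# Minimal sketch of the reshape workaround, assuming the question's
# (Batch, Time, Height, Width, Channel) layout.
example = np.random.randint(0, 255, (16, 16, 56, 56, 3))
example_tensor = tf.convert_to_tensor(example, dtype=tf.float32)

bn = tf.keras.layers.BatchNormalization(axis=-1)

b, t, h, w, c = example_tensor.shape
x = tf.reshape(example_tensor, (b * t, h, w, c))  # fold Time into Batch -> 4-D
x = bn(x, training=True)                          # channel-wise batch norm on 4-D input
x = tf.reshape(x, (b, t, h, w, c))                # restore (Batch, Time, H, W, C)

print(x.shape)  # (16, 16, 56, 56, 3)

As an alternative, the BatchNormalization layer has a documented fused constructor argument; constructing the layer with fused=False should avoid the FusedBatchNormV3 kernel that the error message complains about, although I have not verified that on a 5-D input.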

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
[1] Solution 1: dayumgrill