tf.keras.BatchNormalization giving unexpected output
import tensorflow as tf
tf.enable_eager_execution()
print(tf.keras.layers.BatchNormalization()(tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])))
print(tf.contrib.layers.batch_norm(tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])))
The output of the above code (in TensorFlow 1.15) is:
tf.Tensor([[ 4.99 69.96] [ 4.99 59.97]], shape=(2, 2), dtype=float32)
tf.Tensor([[ 0. 0.99998] [ 0. -0.99998]], shape=(2, 2), dtype=float32)
Why do these two batch normalization calls give completely different outputs? I tried changing some of the parameters of both functions, but the results stayed the same. The second output is the one I want, and PyTorch's batch norm also gives the same output as the second one, so I suspect the issue is with Keras.
Does anyone know how to fix batch norm in Keras?
Solution 1:[1]
The BatchNormalization layer behaves differently during training and inference:
During training (i.e. when using fit() or when calling the layer/model with the argument training=True), the layer normalizes its output using the mean and standard deviation of the current batch of inputs.
During inference (i.e. when using evaluate() or predict(), or when calling the layer/model with the argument training=False, which is the default), the layer normalizes its output using a moving average of the mean and standard deviation of the batches it has seen during training.
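So the first result comes from the Keras layer defaulting to training=False, while the second comes from tf.contrib.layers.batch_norm defaulting to is_training=True. A freshly created layer has seen no batches yet, so its moving mean is still 0 and its moving variance is still 1, which is why the inference-mode output is almost identical to the input.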
If you want the same result from both, you can pass training=True to the Keras layer:
x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
print(tf.keras.layers.BatchNormalization()(x, training=True).numpy().tolist())
print(tf.contrib.layers.batch_norm(x).numpy().tolist())
#output
#[[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
#[[0.0, 0.9999799728393555], [0.0, -0.9999799728393555]]
or pass is_training=False to the contrib function:
x = tf.convert_to_tensor([[5.0, 70.0], [5.0, 60.0]])
print(tf.keras.layers.BatchNormalization()(x).numpy().tolist())
print(tf.contrib.layers.batch_norm(x, is_training=False).numpy().tolist())
#output
#[[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
#[[4.997501850128174, 69.96502685546875], [4.997501850128174, 59.97002410888672]]
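Both sets of numbers above can be reproduced by hand. The following is a minimal pure-Python sketch of the arithmetic, not the TensorFlow implementation. It assumes epsilon = 0.001, a scale (gamma) of 1 and an offset (beta) of 0, and moving statistics that are still at their initial values (mean 0, variance 1). The helper names batch_norm_training and batch_norm_inference are hypothetical and exist only for this illustration.

```python
import math

EPSILON = 1e-3  # assumed epsilon; gamma=1 and beta=0 are assumed as well


def batch_norm_training(batch, eps=EPSILON):
    """Normalize each column with the mean and variance of this batch."""
    n = len(batch)
    columns = list(zip(*batch))
    means = [sum(col) / n for col in columns]
    variances = [
        sum((v - m) ** 2 for v in col) / n  # population (biased) variance
        for col, m in zip(columns, means)
    ]
    return [
        [(v - m) / math.sqrt(var + eps) for v, m, var in zip(row, means, variances)]
        for row in batch
    ]


def batch_norm_inference(batch, moving_mean, moving_var, eps=EPSILON):
    """Normalize each column with stored moving statistics instead."""
    return [
        [(v - m) / math.sqrt(var + eps) for v, m, var in zip(row, moving_mean, moving_var)]
        for row in batch
    ]


x = [[5.0, 70.0], [5.0, 60.0]]

# Training mode: column means are [5, 65] and column variances are [0, 25].
train_out = batch_norm_training(x)

# Inference mode on an untrained layer: moving mean 0, moving variance 1,
# so each value is only divided by sqrt(1 + epsilon).
infer_out = batch_norm_inference(x, moving_mean=[0.0, 0.0], moving_var=[1.0, 1.0])

print(train_out)
print(infer_out)
```

Under these assumptions, train_out should come out close to the training=True output and infer_out close to the is_training=False output shown above, up to float32 rounding in the TensorFlow results.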
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 |
