In TensorFlow, how to write an if-else statement?

Why are both branches of the if-else statement executed? Both "y" and "n" are printed when I run the following code:

if tf.reduce_all(tf.math.pow(s1, 4) == 0):
    print("y")
else:
    print("n")

Full code:

import tensorflow as tf
from tensorflow import keras

class MAGOptimizer(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=0.01, name="SGOptimizer", **kwargs):
        """Call super().__init__() and use _set_hyper() to store hyperparameters"""
        super().__init__(name, **kwargs)
        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) # handle lr=learning_rate
        self._is_first = True
    
    def _create_slots(self, var_list):
        for var in var_list:
            self.add_slot(var, "pv") #previous variable i.e. weight or bias
        for var in var_list:
            self.add_slot(var, "pg") #previous gradient
  
    @tf.function
    def _resource_apply_dense(self, grad, var):
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype) # handle learning rate decay

        new_var_m = tf.math.subtract(var, tf.multiply(grad, lr_t))
        
        pv_var = self.get_slot(var, "pv")
        pg_var = self.get_slot(var, "pg")
        
        if self._is_first:
            self._is_first = False
            new_var = new_var_m
          
        else:

            new_var = tf.math.subtract(var, tf.multiply(grad, lr_t))

        s1 = tf.math.subtract(var, pv_var)

        if tf.reduce_all(tf.math.pow(s1, 4) == 0):
            print("y") 
        else:
            print("n") 

            
        pv_var.assign(var)
        pg_var.assign(grad)
        var.assign(new_var)


    def _resource_apply_sparse(self, grad, var):
        raise NotImplementedError

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
        }


model = keras.models.Sequential([keras.layers.Dense(1, input_shape=[8])])
model.add(keras.layers.Dense(2))
model.add(keras.layers.Dense(1))

model.compile(loss="mse", optimizer=MAGOptimizer(learning_rate=0.001))
model.fit(X_train_scaled, y_train, epochs=2)

The output is:

Train on 11610 samples
Epoch 1/2
y
n
y
n
y
n
y
n
y
n
WARNING:tensorflow:5 out of the last 5 calls to <function MAdaGradOptimizer._resource_apply_dense at 0x000001D4D8760828> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.
11610/11610 [==============================] - 3s 235us/sample - loss: 2.9868
Epoch 2/2
11610/11610 [==============================] - 1s 89us/sample - loss: 1.3655
<tensorflow.python.keras.callbacks.History at 0x1d4d8c3f908>

At run time, does it execute one branch or both branches?



Solution 1:[1]

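This is documented behavior that occurs when you use Python's print instead of tf.print in graph mode (inside a tf.function). Because the condition is a tensor, AutoGraph converts the if/else into a tf.cond, and building that tf.cond traces both branches. Python's print runs during tracing, so "y" and "n" are each printed once per trace. That is why the pair repeats during the first epoch and then stops once the traced graphs are reused. At run time only one branch executes, but a Python print is not part of the graph, so it prints nothing. The TensorFlow guide "Better performance with tf.function" describes the general rule: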

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not at all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code.
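To see why both branches produce output at trace time, here is a minimal sketch that writes out by hand roughly what AutoGraph generates for an if on a tensor condition: a tf.cond with one function per branch. The names check and traced_branches are illustrative only. The list records which branch functions were traced. tf.ones is used instead of random input so the result is deterministic.

```python
import tensorflow as tf

traced_branches = []  # Python list: appended to only while a branch is being traced


@tf.function
def check(s1):
    def true_fn():
        traced_branches.append("y")  # Python side effect, runs at trace time only
        return tf.constant("y")

    def false_fn():
        traced_branches.append("n")  # Python side effect, runs at trace time only
        return tf.constant("n")

    # Building the graph traces BOTH true_fn and false_fn;
    # at run time the graph executes only the branch the condition selects.
    return tf.cond(tf.reduce_all(tf.math.pow(s1, 4) == 0), true_fn, false_fn)


first = check(tf.ones((5, 5)))   # first call: traces, so both branch functions run once
second = check(tf.ones((5, 5)))  # same input signature: the traced graph is reused

print(sorted(traced_branches))        # ['n', 'y']  (each branch traced exactly once)
print(first.numpy(), second.numpy())  # b'n' b'n'
```

The Python side effects from both branches appear exactly once, even though the function is called twice, while the returned value reflects only the branch actually taken.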

Here is an example:

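import tensorflow as tf

@tf.function
def test(s1):
    if tf.reduce_all(tf.math.pow(s1, 4) == 0):
        print("y")
    else:
        print("n")
    return s1

test(tf.random.normal((5, 5)))

Output (both branches are traced, so both Python print calls run once):

y
n

With tf.print, the printing becomes part of the graph, so only the branch actually taken prints, and it prints on every call:

@tf.function
def test(s1):
    if tf.reduce_all(tf.math.pow(s1, 4) == 0):
        tf.print("y")
    else:
        tf.print("n")
    return s1

test(tf.random.normal((5, 5)))

Output:

n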
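The same rule applies to the self._is_first flag in the question's _resource_apply_dense. A Python attribute that is read and reassigned inside a tf.function is only evaluated while tracing, so the traced graph keeps whichever branch was taken at trace time. The minimal sketch below uses a hypothetical Stepper class in place of the optimizer to contrast a Python bool with a tf.Variable flag, whose check and update become part of the graph.

```python
import tensorflow as tf


class Stepper:
    """Hypothetical stand-in for an object that keeps a 'first call' flag."""

    def __init__(self):
        self.is_first_py = True               # Python bool, like self._is_first
        self.is_first_tf = tf.Variable(True)  # graph-visible boolean state

    @tf.function
    def step(self):
        # Python condition: evaluated once, while tracing. The reassignment
        # below also happens only during tracing.
        if self.is_first_py:
            self.is_first_py = False
            py_result = tf.constant(1)
        else:
            py_result = tf.constant(2)

        # Tensor condition: AutoGraph converts this into tf.cond, so the
        # check and the assignment run on every call.
        if self.is_first_tf.read_value():
            self.is_first_tf.assign(False)
            tf_result = tf.constant(1)
        else:
            tf_result = tf.constant(2)
        return py_result, tf_result


stepper = Stepper()
results = [tuple(int(t.numpy()) for t in stepper.step()) for _ in range(3)]
print(results)  # [(1, 1), (1, 2), (1, 2)]
```

The Python-flag result is frozen at 1 because the graph was traced while the flag was still True. The tf.Variable flag switches after the first call. In the question's optimizer, a Python flag like self._is_first would therefore not change behavior after the first trace.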

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1