TensorFlow Batch Hessian

I'm building a neural network that must approximate some multivariate function, say f(x). The loss function is defined by how close the second derivative of the network is to that of the function f. To compute this, I need the Hessian of the network output with respect to x. I wrote a custom TensorFlow model that looks roughly like this:

class ApproximateModel(tf.keras.Model):
    @tf.function
    def f_true_hessian(self, x: tf.Tensor) -> tf.Tensor:
        # Some function that should return the actual Hessian
        return x
    
    def train_step(self, data):
        # x.shape -> (batch_size, dimension_x)
        x = data[0]
        
        # Calculate loss
        with tf.GradientTape() as second_tape:
            with tf.GradientTape() as first_tape:
                first_tape.watch(x)
                second_tape.watch(x)
                
                f = self(x, training=True)
            
            f_x = first_tape.gradient(f, x)
            second_tape.watch(f_x)
        
        f_jacobian = second_tape.jacobian(f_x, x)
        # f_jacobian.shape -> (batch_size, dimension_x, batch_size, dimension_x)
        
        # I want to get (batch_size, dimension_x, dimension_x) somehow..
        loss = tf.math.reduce_mean(
            tf.math.square(
                tf.reduce_sum(f_jacobian, axis=[1, 2]) - self.f_true_hessian(x)
            )
        )
        return loss

For the interested reader, the application of this type of network is to approximate PDEs, as in here.

The code above works well when there is no batch dimension, but I can't figure out how to get the Hessian for a batch of samples of x. How do I get my desired output, where only the Hessian over dimension_x is computed per sample and the cross-batch terms are omitted?
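One way to get the per-sample shape directly is `tf.GradientTape.batch_jacobian`, which treats the leading axis as independent samples and so never materialises the zero cross-sample blocks. A minimal sketch, using a stand-in quadratic function in place of the model (so the expected Hessian is known to be 2·I per sample):

```python
import tensorflow as tf

# Toy setup: a batch of 4 points in 3 dimensions. The stand-in "network"
# is f(x) = sum(x**2) per sample, whose Hessian is 2*I for every sample.
batch_size, dim = 4, 3
x = tf.random.normal((batch_size, dim))

with tf.GradientTape() as second_tape:
    second_tape.watch(x)
    with tf.GradientTape() as first_tape:
        first_tape.watch(x)
        f = tf.reduce_sum(tf.square(x), axis=1)  # stand-in for self(x, training=True)
    f_x = first_tape.gradient(f, x)              # shape (batch_size, dim)

# batch_jacobian differentiates each sample's gradient only with respect
# to that sample's input, giving shape (batch_size, dim, dim) directly.
hessian = second_tape.batch_jacobian(f_x, x)
```

With this, `hessian` already has the desired `(batch_size, dimension_x, dimension_x)` shape, so no `reduce_sum` over batch axes is needed before comparing against the true Hessian.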



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow