The interaction between two networks in a TensorFlow custom loss function

Assume you have two TensorFlow models, model_A and model_B, and a training loop that looks something like this:

with tf.GradientTape() as tape:
    output_A = model_A(inputs)
    output_B = model_B(inputs)
    loss = loss_fn(output_A, output_B, true_output_A, true_output_B)

grads = tape.gradient(loss, model_A.trainable_variables)
optimizer.apply_gradients(zip(grads, model_A.trainable_variables))


and the loss function is defined as:

def loss_fn(output_A, output_B, true_output_A, true_output_B):
    loss = (output_A + output_B) - (true_output_A + true_output_B)
    return loss

The loss function used to update model_A depends on the output of another network (output_B). How does TensorFlow handle this situation?

Does it use the weights of model_B when computing the gradient, or does it treat output_B as a constant and not trace its origins?



Solution 1:[1]

It won't use model_B's weights: tape.gradient computes gradients only with respect to the variables you pass it (here model_A.trainable_variables), and apply_gradients touches only those same variables. output_B is still recorded on the tape; TensorFlow simply never computes or applies a gradient for model_B's variables, so they stay unchanged.

For example:

import tensorflow as tf

# Model1
cnnin = tf.keras.layers.Input(shape=(10, 10, 1))
x = tf.keras.layers.Conv2D(8, 4)(cnnin)
x = tf.keras.layers.Conv2D(16, 4)(x)
x = tf.keras.layers.Conv2D(32, 2)(x)
x = tf.keras.layers.Conv2D(64, 2)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(4)(x)
x = tf.keras.layers.Dense(4, activation="relu")(x)
cnnout = tf.keras.layers.Dense(1, activation="linear")(x)


# Model 2
mlpin = tf.keras.layers.Input(shape=(10, 10, 1), name="mlp_input")
z = tf.keras.layers.Dense(4, activation="sigmoid")(mlpin)
z = tf.keras.layers.Dense(4, activation="softmax")(z)
z = tf.keras.layers.Flatten()(z)
z = tf.keras.layers.Dense(4)(z)
mlpout = tf.keras.layers.Dense(1, activation="linear")(z)

Define the loss function:

def loss_fn(output_A, output_B, true_output_A, true_output_B):
    output_A = tf.reshape(output_A, [-1])
    output_B = tf.reshape(output_B, [-1])
    pred = tf.reduce_sum(output_A + output_B)
    inputs = tf.reduce_sum(true_output_A + true_output_B)
    loss = inputs-pred
    return loss
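As a quick sanity check (this snippet is not in the original answer; the loss definition is repeated so it runs standalone), the function collapses a batch to a single signed scalar:

```python
import tensorflow as tf

# Same loss as above, repeated here so the snippet is self-contained.
def loss_fn(output_A, output_B, true_output_A, true_output_B):
    output_A = tf.reshape(output_A, [-1])
    output_B = tf.reshape(output_B, [-1])
    pred = tf.reduce_sum(output_A + output_B)
    inputs = tf.reduce_sum(true_output_A + true_output_B)
    return inputs - pred

# Batch of 6: predictions are all ones, targets all zeros,
# so the loss is 0 - (6 + 6) = -12.
loss = loss_fn(tf.ones((6, 1)), tf.ones((6, 1)),
               tf.zeros((6, 1)), tf.zeros((6, 1)))
print(float(loss))  # -12.0
```

Note that this loss is a signed difference rather than a squared or absolute error, so minimizing it simply pushes the summed predictions above the summed targets; it serves here only to demonstrate gradient flow.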

Customize what happens in Model.fit

loss_tracker = tf.keras.metrics.Mean(name="custom_loss")
class TestModel(tf.keras.Model):
    def __init__(self, model1, model2):
        super(TestModel, self).__init__()
        self.model1 = model1
        self.model2 = model2
    def compile(self, optimizer):
        super(TestModel, self).compile()
        self.optimizer = optimizer
    def train_step(self, data):
        x, (y1, y2) = data
        with tf.GradientTape() as tape:
            ypred1 = self.model1([x], training=True)
            ypred2 = self.model2([x], training=True)
            loss_value = loss_fn(ypred1, ypred2, y1, y2)
        # Compute gradients
        trainable_vars = self.model1.trainable_variables
        gradients = tape.gradient(loss_value, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        loss_tracker.update_state(loss_value)
        return {"loss": loss_tracker.result()}
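As an aside (not part of the original answer): if you ever want output_B treated as a constant explicitly, so that no gradient can flow into model2 even when its variables are listed, TensorFlow provides tf.stop_gradient. A minimal sketch with single-variable stand-ins:

```python
import tensorflow as tf

a = tf.Variable(2.0)  # stand-in for model_A's weights
b = tf.Variable(3.0)  # stand-in for model_B's weights

with tf.GradientTape() as tape:
    out_a = a * a                    # d(out_a)/da = 2a = 4
    out_b = tf.stop_gradient(b * b)  # severed from the tape: a constant
    loss = out_a + out_b

grad_a, grad_b = tape.gradient(loss, [a, b])
print(float(grad_a))  # 4.0
print(grad_b)         # None: no path from the loss back to b
```

In the train_step above this is unnecessary, since model2's variables are never passed to tape.gradient anyway, but stop_gradient makes the "treat it as a constant" intent explicit in the graph itself.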

Build model1 and model2 and save them before training, so the saved copies keep the initial weights for comparison:

model1= tf.keras.models.Model(cnnin, cnnout, name="model1")
model2 = tf.keras.models.Model(mlpin, mlpout, name="model2")

model1.save('test_model1.h5')
model2.save('test_model2.h5')

import numpy as np

x = np.random.rand(6, 10, 10, 1)
y1 = np.random.rand(6, 1)
y2 = np.random.rand(6, 1)

trainable_model = TestModel(model1, model2)
trainable_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001))
trainable_model.fit(x=x, y=(y1, y2), epochs=10)

Gives the following output:

Epoch 1/10
1/1 [==============================] - 0s 375ms/step - loss: 7.9465
Epoch 2/10
1/1 [==============================] - 0s 6ms/step - loss: 7.8509
Epoch 3/10
1/1 [==============================] - 0s 6ms/step - loss: 7.7547
Epoch 4/10
1/1 [==============================] - 0s 6ms/step - loss: 7.6577
Epoch 5/10
1/1 [==============================] - 0s 5ms/step - loss: 7.5600
Epoch 6/10
1/1 [==============================] - 0s 4ms/step - loss: 7.4608
Epoch 7/10
1/1 [==============================] - 0s 4ms/step - loss: 7.3574
Epoch 8/10
1/1 [==============================] - 0s 6ms/step - loss: 7.2514
Epoch 9/10
1/1 [==============================] - 0s 5ms/step - loss: 7.1429
Epoch 10/10
1/1 [==============================] - 0s 5ms/step - loss: 7.0323

Then load the saved models and compare the trainable weights:

test_model1 = tf.keras.models.load_model('test_model1.h5')
test_model2 = tf.keras.models.load_model('test_model2.h5')

Compare model1 trainable_weights before and after training (they should all change):

for i in range(len(model1.trainable_weights)):
    print(np.array_equal(model1.trainable_weights[i], test_model1.trainable_weights[i]))

Outputs:

False
False
False
False
False
False
False
False
False
False
False
False
False
False

Compare model2 trainable_weights before and after training (they should all be the same):

for i in range(len(model2.trainable_weights)):
    print(np.array_equal(model2.trainable_weights[i], test_model2.trainable_weights[i]))

Outputs:

True
True
True
True
True
True
True
True

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source

Solution 1: delirium78