Compute gradients across two layers using gradients calculated from a previous layer using tf.gradients or tf.GradientTape

I want to use the gradients of one layer to calculate the gradients of the layer that comes before it.

My motivation: when I tried model parallelism with tf.device, I found that backpropagation was running on the CPU. The entire backward pass started running on the chosen tf.device only after I wrapped the GradientTape call that computes the gradient in the tf.device context manager. Since the model is split, I want the backprop of each partition to execute on the device where that partition is placed.

Ideally, I would like a way to make something like this oversimplified pseudocode work:

with tf.device(device_3):
   grad_3 = tf.gradients(loss, trainable_vars_of_partition_3)
 
with tf.device(device_2):
   grad_2 = tf.gradients(grad_3, trainable_vars_of_partition_2)

with tf.device(device_1):
   grad_1 = tf.gradients(grad_2, trainable_vars_of_partition_1)

grads = concat(grad_1, grad_2, grad_3)

If something like this exists then I would be overjoyed if you could point me in the right direction.
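
A minimal sketch of one way to approximate the pseudocode above in eager mode, assuming the `output_gradients` argument of `tf.GradientTape.gradient` is used to hand each partition the gradient arriving from the partition after it. The `DEVICES` list, the plain `tf.Variable` partitions, and the layer sizes are placeholders for illustration (every device is set to "/CPU:0" so the snippet runs anywhere). Whether each backward op really lands on the requested device is not verified here; `tf.debugging.set_log_device_placement(True)` can be used to check.

```python
import tensorflow as tf

# Hypothetical placement; replace with e.g. "/GPU:0", "/GPU:1", "/GPU:2".
DEVICES = ["/CPU:0", "/CPU:0", "/CPU:0"]

tf.random.set_seed(0)
x = tf.random.uniform([4, 10])
y = tf.zeros([4, 1])

# One [kernel, bias] pair per partition, created on that partition's device.
sizes = [(10, 5), (5, 2), (2, 1)]
params = []
for dev, (n_in, n_out) in zip(DEVICES, sizes):
    with tf.device(dev):
        params.append([tf.Variable(tf.random.normal([n_in, n_out])),
                       tf.Variable(tf.zeros([n_out]))])

# Forward pass; a persistent tape allows several gradient() calls afterwards.
with tf.GradientTape(persistent=True) as tape:
    acts = [x]
    for dev, (w, b) in zip(DEVICES, params):
        with tf.device(dev):
            acts.append(tf.matmul(acts[-1], w) + b)
    with tf.device(DEVICES[-1]):
        loss = tf.reduce_mean(tf.square(acts[-1] - y))

# Backward pass, last partition first. Each partition receives the gradient
# of the loss w.r.t. its output and passes one on w.r.t. its input.
with tf.device(DEVICES[2]):
    d_o3 = tape.gradient(loss, acts[3])
    g3 = tape.gradient(acts[3], params[2], output_gradients=d_o3)
    d_o2 = tape.gradient(acts[3], acts[2], output_gradients=d_o3)
with tf.device(DEVICES[1]):
    g2 = tape.gradient(acts[2], params[1], output_gradients=d_o2)
    d_o1 = tape.gradient(acts[2], acts[1], output_gradients=d_o2)
with tf.device(DEVICES[0]):
    g1 = tape.gradient(acts[1], params[0], output_gradients=d_o1)

grads = g1 + g2 + g3  # list concatenation: [w1, b1, w2, b2, w3, b3]

# Reference: a single end-to-end gradient call, for comparison only.
reference = tape.gradient(loss, [v for pair in params for v in pair])
del tape
```

In this sketch only the small tensors `d_o2` and `d_o1` cross partition boundaries, which mirrors how activations cross them in the forward pass.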

Unfortunately, I could not find anything as simple as this. The next best approach I could think of was to use the gradients of one layer to find the gradients of the layer that comes before it. With the chain rule and backpropagation, I feel that this should be possible.

I created this toy example; solving it is the first step towards the final goal.

Let's say we have a model with 3 dense layers without activation functions. x and y are defined as follows:

import tensorflow as tf

x = tf.concat([tf.random.uniform([1, 10], minval=0, maxval=0.25),
               tf.random.uniform([1, 10], minval=0.25, maxval=0.5),
               tf.random.uniform([1, 10], minval=0.5, maxval=0.75),
               tf.random.uniform([1, 10], minval=0.75, maxval=1.),
               ], axis=0)

y = tf.constant(0., shape=[4, 1])

d1 = tf.keras.layers.Dense(5, name='d1') 
d2 = tf.keras.layers.Dense(2, name='d2') 
d3 = tf.keras.layers.Dense(1, name='d3') 

I am using a tf.function in this toy example, but an answer that uses eager mode with GradientTape would also be appreciated.

@tf.function
def tf_func(x, y, d1, d2, d3):
    # Short aliases for these functions keep the code below neater and more readable.
    g = tf.gradients
    rs = tf.reduce_sum
    rm = tf.reduce_mean

    o1 = d1(x)
    o2 = d2(o1)
    o3 = d3(o2)

    l = tf.reduce_mean(tf.square(o3 - y))
    
    w3, w2, w1 = d3.trainable_variables, d2.trainable_variables, d1.trainable_variables

    tf.print('actual grads' + '=' * 80)

    dl_dw3 = g(l, w3)
    
    dl_dw2 = g(l, w2)
    tf.print('dl_dw2: \n', dl_dw2)

    dl_dw1 = g(l, w1)   

    tf.print()
    tf.print()
    
    tf.print('reference grads' + '=' * 80)
    dl_do1 = g(l, o1)
    dl_do2 = g(l, o2)
    tf.print('dl_do2: \n', dl_do2)
    dl_do3 = g(l, o3)

    dl_dw1 = g(l, w1)
    dl_dw2 = g(l, w2)
    dl_dw3 = g(l, w3)

    do3_o2 = g(o3, o2)
    do2_do1 = g(o2, o1)

    do3_w3 = g(o3, w3)
    do2_w2 = g(o2, w2)
    do1_w1 = g(o1, w1)


    tf.print('testing chain_rule method' + '=' * 80)
    
    # A 't' prefix marks derivatives obtained via the chain rule, to distinguish them from the reference grads

    tdl_do3 = g(l, o3) # same as ref_grads

    tdo3_dw3 = g(o3, w3) # same as ref_grads
    tdl_dw3 = [rm(tdl_do3) * tdo3_dw3[0], rm(tdl_do3) * tdo3_dw3[1]] # same as actual grads

    tdo3_do2 = g(o3, o2) # same as ref_grads

    tdl_do2 = tdo3_do2 * rm(tdl_do3, axis=0)  # same as ref_grads
    tf.print('tdl_do2: \n', tdl_do2)

    tdo2_dw2 = g(o2, w2) 
    tf.print('tdo2_dw2: \n', tdo2_dw2)
    
    tdl_dw2 = [tdo2_dw2[0] * rm(tdl_do2, axis=[1]), tdo2_dw2[1] * rm(tdl_do2, axis=[1])]
    tf.print('tdl_dw2: \n', tdl_dw2)

    return None 


tf_func(x, y, d1, d2, d3)

The output was:

actual grads================================================================================
dl_dw2: 
 [[[-3.04819393 -1.30051827]
 [5.02123785 2.14232159]
 [-0.260933906 -0.111328]
 [5.87596226 2.50699162]
 [1.9655633 0.838611722]], [-4.69162369 -2.0016911]]


reference grads================================================================================
dl_do2: 
 [[[-0.43842113 -0.187053293]
 [-0.889310718 -0.379426271]
 [-1.41650343 -0.604354143]
 [-1.94738865 -0.830857456]]]


testing chain_rule method================================================================================
tdl_do2: 
 [[[-0.43842113 -0.187053293]
  [-0.889310718 -0.379426271]
  [-1.41650343 -0.604354143]
  [-1.94738865 -0.830857456]]]
tdo2_dw2: 
 [[[2.10966444 2.10966444]
 [-3.48670244 -3.48670244]
 [0.22972326 0.22972326]
 [-3.95618558 -3.95618558]
 [-1.3790133 -1.3790133]], [4 4]]
tdl_dw2: 
 [[[-2.47443795 -1.05572414]
 [4.08957386 1.74482536]
 [-0.26944378 -0.114958748]
 [4.64023352 1.97976542]
 [1.61745286 0.690089643]], [[-4.69162369 -2.0016911]]]

For some reason, the gradients wrt the weights in tdl_dw2 and dl_dw2 differ. The values in tdl_dw2 are close to those in dl_dw2 but not equal, even though the gradients wrt the biases are the same. I cannot figure out why.

The gradient of the loss wrt w3 is as expected.

I used tf.reduce_mean to replicate what tf.gradients was doing internally as far as I understand. Please correct me if I am wrong.
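
One likely reason for the mismatch, offered as a sketch rather than a confirmed diagnosis: `g(o2, w2)` differentiates the sum of `o2`, so it already collapses the batch, and multiplying that by a batch mean of `dl_do2` is not the same as summing per-example products. The NumPy example below uses random placeholder arrays (`o1` and `dl_do2` are stand-ins, not the values from the run above) to compare the two computations for a dense layer `o2 = o1 @ W2 + b2`.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 4
o1 = rng.normal(size=(batch, 5))      # placeholder for the input to d2
dl_do2 = rng.normal(size=(batch, 2))  # placeholder for the upstream gradient dl/do2

# Chain rule for o2 = o1 @ W2 + b2: sum over examples of per-example products.
dl_dW2 = o1.T @ dl_do2                # shape (5, 2)
dl_db2 = dl_do2.sum(axis=0)           # shape (2,)

# The shortcut from the question: g(o2, w2) differentiates sum(o2), so it
# yields the batch sum of o1 (kernel) and the batch size (bias); both are
# then scaled by the batch mean of dl_do2.
do2_dW2 = o1.T @ np.ones((batch, 2))
do2_db2 = np.full(2, float(batch))
t_dW2 = do2_dW2 * dl_do2.mean(axis=0)
t_db2 = do2_db2 * dl_do2.mean(axis=0)

print("bias grads match:  ", np.allclose(t_db2, dl_db2))
print("kernel grads match:", np.allclose(t_dW2, dl_dW2))
```

The bias case works because the batch size times the mean equals the sum. The kernel case multiplies a sum by a mean, which in general differs from the sum of per-example products, and that is consistent with the output above: 2.10966444 times the mean of the first column of dl_do2 (about -1.1729) gives the -2.47443795 seen in tdl_dw2.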

From TensorFlow's documentation:

gradients() adds ops to the graph to output the derivatives of ys with respect to xs. It returns a list of Tensor of length len(xs) where each tensor is the sum(dy/dx) for y in ys and for x in xs.

tf.gradients constructs symbolic derivatives of sum of ys w.r.t. x in xs.
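
The summing behaviour in the quoted documentation can be seen directly, and the `grad_ys` argument of `tf.gradients` replaces the implicit weighting of ones with a supplied upstream gradient. The sketch below is illustrative only: it assumes bias-free `tf.Variable` kernels (`w1`, `w2`, `w3`) in place of the Dense layers, and `compare` is a hypothetical function name.

```python
import tensorflow as tf

tf.random.set_seed(0)
x = tf.random.uniform([4, 10])
y = tf.zeros([4, 1])
w1 = tf.Variable(tf.random.normal([10, 5]))
w2 = tf.Variable(tf.random.normal([5, 2]))
w3 = tf.Variable(tf.random.normal([2, 1]))

@tf.function
def compare():
    o1 = tf.matmul(x, w1)
    o2 = tf.matmul(o1, w2)
    o3 = tf.matmul(o2, w3)
    l = tf.reduce_mean(tf.square(o3 - y))

    direct = tf.gradients(l, [w2])[0]   # reference dl/dw2
    dl_do2 = tf.gradients(l, [o2])[0]   # upstream gradient, shape [4, 2]

    # With no grad_ys, every element of o2 is weighted by 1, i.e. this is
    # the gradient of sum(o2), written out explicitly on the next line.
    summed = tf.gradients(o2, [w2])[0]
    summed_explicit = tf.gradients(tf.reduce_sum(o2), [w2])[0]

    # With grad_ys, each element of o2 is weighted by dl/do2 instead.
    chained = tf.gradients(o2, [w2], grad_ys=[dl_do2])[0]
    return direct, chained, summed, summed_explicit

direct, chained, summed, summed_explicit = compare()
print("max |direct - chained|:", float(tf.reduce_max(tf.abs(direct - chained))))
print("max |summed - summed_explicit|:",
      float(tf.reduce_max(tf.abs(summed - summed_explicit))))
```

Under these assumptions, `chained` is the per-layer chain-rule result, and no `tf.reduce_mean` over the upstream gradient is involved.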

Any guidance or help will be greatly appreciated, thank you.

Some similar Stack Overflow questions (there are many more):

  1. Compute gradients across two models
  2. Is it possible to acquire an intermediate gradient? (Tensorflow)
  3. Breaking TensorFlow gradient calculation into two (or more) parts

Here is a colab notebook with the code: https://colab.research.google.com/drive/1034hu6Zo766-spKu5qfeG4c2Yv2v-DGM?usp=sharing



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
