'Tensorflow2: Getting gradient of output w.r.t to intermediate layer activation
I am working on a model from which I need to extract the gradient of output w.r.t intermediate layer activations. A minimum working example is given below
import tensorflow as tf
from tensorflow.keras.layers import (
BatchNormalization,
Conv2D,
ZeroPadding2D,
LeakyReLU,
Input
)
from tensorflow.keras.regularizers import l2
def convolutional(input_layer, filter_shape, downsample=False, activate=True, bn=True):
"""
This is from Tensorflow2 YOLO model here
https://github.com/pythonlessons/TensorFlow-2.x-YOLOv3/blob/master/yolov3/yolov3.py#L32
"""
if downsample:
input_layer = ZeroPadding2D(((1, 0), (1, 0)))(input_layer)
padding = 'valid'
strides = 2
else:
strides = 1
padding = 'same'
conv = Conv2D(filters=filter_shape[-1],
kernel_size=filter_shape[0],
strides=strides,
padding=padding,
use_bias=not bn,
kernel_regularizer=l2(0.0005),
kernel_initializer=tf.random_normal_initializer(stddev=0.01),
bias_initializer=tf.constant_initializer(0.))(input_layer)
if bn:
conv = BatchNormalization()(conv)
if activate == True:
conv = LeakyReLU(alpha=0.1)(conv)
return conv
def main():
input_size = (32, 32, 3)
in_layer = Input(input_size)
x = convolutional(in_layer, (3, 3, 3, 5))
x = convolutional(x, (3, 3, 5, 5), downsample=True)
x = convolutional(x, (3, 3, 5, 5), downsample=True)
out_layer = convolutional(x, (3, 3, 5, 1))
model = tf.keras.Model(in_layer, out_layer)
model.summary()
x1 = tf.random.uniform((1, *input_size), minval=0, maxval=1.0) # sample input
l1 = model.get_layer('conv2d_2').output # Get gradient with respect to this layer
with tf.GradientTape() as tape:
tape.watch(x1)
y1 = model(x1)
l = tf.reduce_mean(y1) # Summary of output
grads = tape.gradient(l, l1) # Get gradient of l w.r.t yo l1 activations
print(grads) # This is None :(
if __name__=="__main__":
main()
In the above code, I wanted to get the gradient of l w.r.t to the activations from conv2d_2 Conv2D layer.
When I run the code above, grads is None. How can I fix this and get the correct gradients.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
