Retaining the graph in PyTorch grad

I have a neural network whose architecture resembles an inner minimization problem. The code structure cannot be changed, but I am struggling to find a way to:

  • compute the gradient with respect to only certain arguments (x) and not others (arg1, arg2)
  • keep the gradient's dependency on the parameters of scalar_net
```python
x = torch.rand(batch_size, feature_size, requires_grad=True)
y = torch.rand(batch_size, feature_size)

for _ in range(10):
    s = scalar_net.forward(x, arg1, arg2, ...)  # shape = (batch_size,)
    # compute gradient of s with respect to x
    gradient_wrt_x = ...
    x += gradient_wrt_x

loss = torch.nn.MSELoss()(x, y)
loss.backward()  # should produce gradients on scalar_net.parameters()
```

Anything would be helpful!

Edit:

Another important note: although it is possible to compute the gradient of s.sum() with respect to x via torch.autograd.grad, that calculation is extremely memory-inefficient and should be avoided.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
