Retaining graph in PyTorch grad
I have a neural network with a specific architecture, similar to a minimization problem. The code structure cannot be changed, but I am struggling to find a way to:

- compute the gradient with respect to only certain arguments (`x`) and not others (`arg1`, `arg2`)
- keep the gradient's dependency on the parameters of `scalar_net`
```python
x = torch.rand(batch_size, feature_size, requires_grad=True)
y = torch.rand(batch_size, feature_size)
for _ in range(10):
    s = scalar_net.forward(x, arg1, arg2, ...)  # shape = (batch_size,)
    # compute gradient of s with respect to x
    gradient_wrt_x = ...
    x += gradient_wrt_x
loss = torch.nn.MSELoss()(x, y)
loss.backward()  # should populate grads on scalar_net.parameters()
```
Anything would be helpful!
Edit:
Another important note: although it is possible to compute the gradient of `s.sum()` with respect to `x` through `torch.autograd.grad`, that calculation is extremely memory inefficient and should be avoided.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow