What's the proper way to update a leaf tensor's values (e.g. during the update step of gradient descent)?
Toy Example
Consider this very simple implementation of gradient descent, whereby I attempt to fit a linear regression (mx + b) to some toy data.
import torch
# Make some data
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)
# Initialize m and b
m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)
# Pass 1
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m - 0.1*m.grad # update m
b = b - 0.1*b.grad # update b
m.grad = None # zero out m gradient
b.grad = None # zero out b gradient
# Pass 2
yhat = X * m + b # Calculate yhat
loss = torch.sqrt(torch.mean((yhat - Y)**2)) # Calculate the loss
loss.backward() # Reverse mode differentiation
m = m - 0.1*m.grad # ERROR
The first pass works fine, but the second pass fails on the last line, m = m - 0.1*m.grad: accessing m.grad triggers the warning below, m.grad is now None, and the multiplication 0.1*m.grad then raises a TypeError.
Error
/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:417.)
return self._grad
My understanding of why this happens is that, during Pass 1, this line
m = m - 0.1*m.grad
creates a brand-new tensor (i.e. a totally separate block of memory) and rebinds the name m to it. Because that new tensor is the result of an operation, m goes from being a leaf tensor to a non-leaf tensor.
# Pass 1
...
print(f"{m.is_leaf}") # True
m = m - 0.1*m.grad
print(f"{m.is_leaf}") # False
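The warning follows directly from this: backward() does not populate .grad on non-leaf tensors, so in Pass 2 m.grad stays None and the update expression blows up. A minimal sketch of the same failure (not from the question itself):

```python
import torch

t = torch.rand(1, requires_grad=True)  # leaf tensor
t = t * 1.0                            # rebinds t to a non-leaf result of an op
t.sum().backward()

print(t.is_leaf)       # False
print(t.grad is None)  # True: backward() does not populate non-leaf .grad
# t - 0.1 * t.grad would now raise TypeError (float * NoneType)
```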
So, how does one perform an update?
I've seen it mentioned that one could use something along the lines of m.data = m - 0.1*m.grad, but I haven't seen much discussion about this technique.
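For reference, a common fix (a sketch, not taken from the question) is to mutate the existing leaf in place inside a torch.no_grad() block: m keeps its identity as a leaf tensor, and autograd does not record the update step in the graph.

```python
import torch

# Same toy setup as in the question
torch.manual_seed(0)
X = torch.rand(35) * 5
Y = 3 * X + torch.rand(35)

m = torch.rand(size=(1,), requires_grad=True)
b = torch.rand(size=(1,), requires_grad=True)

for _ in range(2):  # two passes, mirroring the question
    yhat = X * m + b
    loss = torch.sqrt(torch.mean((yhat - Y) ** 2))
    loss.backward()
    with torch.no_grad():      # do not record the update in the autograd graph
        m -= 0.1 * m.grad      # in-place: m stays the same leaf tensor
        b -= 0.1 * b.grad
    m.grad = None              # zero out gradients for the next pass
    b.grad = None

print(m.is_leaf, b.is_leaf)  # True True
```

The m.data approach from the question also avoids the error, but it bypasses autograd's checks and is generally discouraged in modern PyTorch; torch.optim.SGD([m, b], lr=0.1) with step() and zero_grad() wraps this same in-place pattern.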
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
