What is the difference between sample() and rsample()?
When I sample from a distribution in PyTorch, both sample and rsample appear to give similar results:
import torch, seaborn as sns
x = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0]))
sns.histplot(x.sample((1000,)).squeeze().numpy())   # histogram of sample()
sns.histplot(x.rsample((1000,)).squeeze().numpy())  # looks essentially the same for rsample()
When should I use sample(), and when should I use rsample()?
Solution 1:[1]
Using rsample allows for pathwise derivatives:
The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable.
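As a concrete illustration (my own sketch, not from the quoted documentation), gradients flow from a function of the sample back to the distribution's parameters only when the sample comes from rsample():

import torch

mu = torch.tensor([0.0], requires_grad=True)
sigma = torch.tensor([1.0], requires_grad=True)
dist = torch.distributions.Normal(mu, sigma)

# Pathwise derivative: the rsample() output is a differentiable function of mu and sigma.
loss = dist.rsample().pow(2).mean()
loss.backward()
print(mu.grad, sigma.grad)  # both gradients are populated

# With sample() the graph is cut, so the same backward() call would fail:
# dist.sample().pow(2).mean().backward()  # RuntimeError: ... does not require grad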
Solution 2:[2]
sample(): random sampling from the probability distribution. So, we cannot backpropagate, because it is random! (the computation graph is cut off).
See the source code of sample in torch.distributions.normal.Normal:
def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    with torch.no_grad():
        return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
torch.normal returns a tensor of random numbers, and the torch.no_grad() context prevents the computation graph from growing any further.
You see, we cannot backprop: the tensor returned by sample() contains just numbers, not the whole computational graph.
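You can check this yourself (a small sketch, not part of the quoted source): the tensor returned by sample() has no grad_fn, even when the distribution's parameters require gradients, whereas rsample() stays attached to them.

import torch

mu = torch.tensor([0.0], requires_grad=True)
dist = torch.distributions.Normal(mu, torch.tensor([1.0]))

s = dist.sample()
print(s.requires_grad, s.grad_fn)   # False, None  -> the graph is cut off

r = dist.rsample()
print(r.requires_grad, r.grad_fn)   # True, <AddBackward0 ...>  -> still connected to mu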
So, what is rsample()?
By using rsample, we can backpropagate, because it keeps the computation graph alive.
How? By putting the randomness aside in a separate parameter. This is called the "reparameterization trick".
rsample(): sampling using the reparameterization trick.
Notice the eps in the source code of rsample in torch.distributions.normal.Normal:
def rsample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape)
    eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
    # `self.loc` is the mean and `self.scale` is the standard deviation.
    return self.loc + eps * self.scale
eps is the separate parameter responsible for the randomness of the sampling.
Look at the return: mean + eps * standard deviation
eps does not depend on the parameters you want to differentiate with respect to.
So now you can freely backpropagate (i.e., differentiate), because eps does not change when the parameters change.
(If we change the parameters, the distribution of the reparameterized samples does change because self.loc and self.scale change, but the distribution of the eps does not change.)
Note that all the randomness of the sampling comes from the random draw of eps; there is no randomness in the computation graph itself. Once eps is drawn, it is fixed (its values do not change after they are sampled).
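To make the trick concrete, here is a small sketch (my own example, assuming a Normal distribution) that reproduces by hand what rsample() does:

import torch

mu = torch.tensor([2.0], requires_grad=True)     # parameters we want to differentiate w.r.t.
sigma = torch.tensor([0.5], requires_grad=True)

eps = torch.randn(1)          # all the randomness lives here; eps needs no gradient
sample = mu + eps * sigma     # deterministic, differentiable function of mu and sigma

loss = sample.pow(2).mean()
loss.backward()
print(mu.grad, sigma.grad)    # d(loss)/d(mu) = 2*sample, d(loss)/d(sigma) = 2*sample*eps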
For example, in an implementation of the SAC (Soft Actor-Critic) algorithm in reinforcement learning, eps may consist of elements corresponding to a single minibatch of actions (and one action may consist of many elements), as in the sketch below.
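A minimal, hypothetical sketch of that idea (the network, dimensions, and surrogate loss are illustrative, not the full SAC objective):

import torch
import torch.nn as nn

policy = nn.Linear(4, 2)                # toy policy head: 4-dim state -> (mean, log_std) of a 1-dim action
states = torch.randn(32, 4)             # a minibatch of 32 states

mean, log_std = policy(states).chunk(2, dim=-1)
dist = torch.distributions.Normal(mean, log_std.exp())

actions = dist.rsample()                # one eps per action element; gradients flow back into the policy
log_probs = dist.log_prob(actions)

q_values = -(actions ** 2)              # stand-in for a critic Q(s, a)
loss = (log_probs - q_values).mean()    # SAC-style actor loss with temperature = 1
loss.backward()                         # works because the actions came from rsample(), not sample()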
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Shai |
| Solution 2 | |


