PyTorch: compute the mean of a 2D tensor at specific rows given a condition
Say I have a tensor
tensor([[0, 1, 2, 2],
        [2, 4, 2, 4],
        [3, 4, 3, 1],
        [4, 4, 4, 3]])
and a tensor of indices
tensor([[1],
        [2],
        [1],
        [3]])
I want to compute the mean over rows whose index values match. In this case rows 1 and 3 (counting from 1) both have index 1, so I want their mean, and the final output would be
tensor([[1.5, 2.5, 2.5, 1.5],
        [2, 4, 2, 4],
        [4, 4, 4, 3]])
Solution 1:
You can use torch.scatter_reduce with reduce='mean' to compute the per-group averages in one call, passing include_self=False so that the initial values of the output tensor are not counted in the mean. (Without reduce='mean' you would have to use it twice: once to compute the sums and once to count the summands, then divide the sums by the counts.) Two details: since PyTorch uses 0-based indexing we need to subtract 1 from the index values, and the index tensor has to be broadcast to the shape of a:
import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])
cc = torch.tensor([[1.5, 2.5, 2.5, 1.5], [2, 4, 2, 4], [4, 4, 4, 3]])  # goal

index = torch.broadcast_to(b, a.shape) - 1  # shift the 1-based labels to 0-based
c = torch.zeros(int(index.max()) + 1, a.shape[1], dtype=torch.float64).scatter_reduce(
    0,
    index,
    a.to(torch.float64),
    reduce='mean',
    include_self=False,  # average only the scattered values, not the initial zeros
)
print(c)
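On PyTorch versions where scatter_reduce does not support reduce='mean' (or the include_self keyword), the two-pass sums-and-counts approach mentioned above still works. A sketch under that assumption, using scatter_add_:

```python
import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])

index = torch.broadcast_to(b, a.shape) - 1  # 0-based group labels, shape (4, 4)
n_groups = int(index.max()) + 1

# Pass 1: scatter-add the values into per-group sums.
sums = torch.zeros(n_groups, a.shape[1], dtype=torch.float64)
sums.scatter_add_(0, index, a.to(torch.float64))

# Pass 2: scatter-add ones to count the summands per group.
counts = torch.zeros_like(sums)
counts.scatter_add_(0, index, torch.ones_like(a, dtype=torch.float64))

means = sums / counts
print(means)
```

Dividing the sums by the counts gives the same group means; note this divides by zero for any group label that never occurs, which the single scatter_reduce call avoids.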
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | flawr |
