PyTorch: compute the mean of a 2D tensor at specific rows given a condition

Say I have a tensor

tensor([[0, 1, 2, 2],
        [2, 4, 2, 4],
        [3, 4, 3, 1],
        [4, 4, 4, 3]])

and a tensor of indices

tensor([[1],
        [2],
        [1],
        [3]])

I want to compute the mean of the rows whose index values match. In this case I want the mean of rows 1 and 3, so the final output would be

tensor([[1.5, 2.5, 2.5, 1.5],
        [2,   4,   2,   4],
        [4,   4,   4,   3]])
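
(Note that I am counting rows 1-based here; in PyTorch's 0-based terms these are rows 0 and 2, the two rows whose index value is 1. The first output row can be checked directly:)

import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]], dtype=torch.float)

# rows 0 and 2 (0-based) both carry index value 1
print((a[0] + a[2]) / 2)  # tensor([1.5000, 2.5000, 2.5000, 1.5000])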


Solution 1:[1]

You can use torch.scatter_reduce. One way is to call it twice: once with reduce='sum' to compute the sums and once more to count the summands, then divide the sums by the counts. But scatter_reduce also supports reduce='mean' directly, so a single call is enough; pass include_self=False so the zero-initialized output buffer is not counted as a summand. One detail: since PyTorch uses 0-based indexing, we need to subtract 1 from the index values:

import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])
cc = torch.tensor([[1.5, 2.5, 2.5, 1.5], [2, 4, 2, 4], [4, 4, 4, 3]])  # goal

index = torch.broadcast_to(b, a.shape) - 1  # subtract 1 for 0-based indexing
c = torch.zeros(int(b.max()), a.shape[1]).scatter_reduce(
    0,                   # scatter along the row dimension
    index,               # target output row for every element of a
    a.float(),           # source values
    reduce='mean',
    include_self=False,  # don't mix the zero-initialized buffer into the mean
)
print(c)
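
For comparison, here is a sketch of the two-call variant mentioned above, reusing a, b, and index from the snippet (the variable names sums and counts are mine): one scatter for the row sums, one for the number of contributing rows, then an elementwise division.

# sum the rows that share a target index
sums = torch.zeros(int(b.max()), a.shape[1]).scatter_reduce(
    0, index, a.float(), reduce='sum', include_self=False
)
# count how many input rows landed on each output row
counts = torch.zeros(int(b.max()), a.shape[1]).scatter_reduce(
    0, index, torch.ones_like(a, dtype=torch.float), reduce='sum', include_self=False
)
print(sums / counts)  # same result as c above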

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: flawr