Why does torch.scatter require a smaller shape for indices than values?

A similar question has already been asked, but I don't think its solution suits my case.

I wonder why it is not possible to do a torch.scatter operation where my index tensor is bigger than my value tensor. In my case I have duplicate indices, as in the following value tensor a and index tensor idx:

import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])

idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

a.scatter(-1, idx, 1) raises:

RuntimeError: Expected index [2, 5] to be smaller than self [2, 4] apart from dimension 1 and to be smaller size than src [2, 4]
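
The error message mirrors the shape rule stated in the PyTorch documentation for Tensor.scatter_: self, index and src (if it is a tensor) must have the same number of dimensions, index.size(d) <= src.size(d) for all dimensions d, and index.size(d) <= self.size(d) for all dimensions d != dim. The message reports a src of shape [2, 4] even though the scalar 1 was passed, which suggests that the PyTorch version used here expanded the scalar to the shape of self. That is an inference from the message, not something the question confirms. The helper scatter_shapes_ok below is hypothetical: a pure-Python sketch that only restates the documented rule, to show which check the call fails.

```python
def scatter_shapes_ok(self_shape, dim, index_shape, src_shape=None):
    """Return True if the shapes satisfy the rule documented for Tensor.scatter_.

    Hypothetical helper: it restates the documented shape rule and does not call
    PyTorch. Pass src_shape=None when a scalar value is scattered.
    """
    ndim = len(self_shape)
    # self, index and src (if given) must have the same number of dimensions
    if len(index_shape) != ndim or (src_shape is not None and len(src_shape) != ndim):
        return False
    dim = dim % ndim  # normalise negative dims such as -1
    for d in range(ndim):
        # index.size(d) <= self.size(d) for every d except the scatter dimension
        if d != dim and index_shape[d] > self_shape[d]:
            return False
        # index.size(d) <= src.size(d) for every d, including the scatter dimension
        if src_shape is not None and index_shape[d] > src_shape[d]:
            return False
    return True


# The failing call from the question: index [2, 5], self [2, 4], src reported as [2, 4]
print(scatter_shapes_ok((2, 4), -1, (2, 5), (2, 4)))  # False: 5 > 4 in the src check
# A src tensor shaped like the index passes the same check
print(scatter_shapes_ok((2, 4), -1, (2, 5), (2, 5)))  # True
```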

Is there another way to achieve this?



Solution 1:[1]

Not a solution, but a workaround:

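import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])

idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# Column vector of row indices, shape [n_rows, 1]
rows = torch.arange(0, a.size(0))[:,None]
n_col = idx.size(1)
# Duplicate (row, column) pairs simply write 1 again
a[rows.repeat(1, n_col), idx] = 1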

rows.repeat(1, n_col) has the same shape as idx and supplies, for each column index in idx, the row it belongs to, so the assignment sets every (row, column) position listed in idx.
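
Two variations on this workaround, sketched under the assumption that only which positions get set matters (every duplicate writes the same value 1). First, advanced indexing broadcasts its index tensors, so the [n_rows, 1] column of row indices can be used directly without repeat. Second, going by the documented shape rule for scatter, passing a src tensor with the same shape as idx instead of the scalar 1 keeps index.size(d) <= src.size(d) in every dimension, so torch.scatter itself should accept the call. Which src element is written for a repeated index is not something to rely on, which is harmless here because all of them are 1. Both variants are compared with the result worked out by hand for the example tensors.

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# Worked out by hand: row 0 gets columns 1, 2, 3 set, row 1 gets columns 0, 1, 2
expected = torch.tensor([[0, 1, 1, 1],
                         [1, 1, 1, 0]])

# Variant 1: a [n_rows, 1] column of row indices broadcasts against idx
# ([n_rows, n_cols]), so rows.repeat(1, n_col) is not needed
via_indexing = a.clone()
rows = torch.arange(a.size(0)).unsqueeze(1)
via_indexing[rows, idx] = 1

# Variant 2: scatter a src tensor shaped like idx instead of the scalar 1;
# src must have the same dtype as a, hence dtype=a.dtype
via_scatter = a.scatter(-1, idx, torch.ones_like(idx, dtype=a.dtype))

print(torch.equal(via_indexing, expected))  # True
print(torch.equal(via_scatter, expected))  # True
```

Both variants leave a unchanged: the first works on a clone, and a.scatter (without the trailing underscore) returns a new tensor.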

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source

[1] Solution 1: Christian