Why does torch.scatter require a smaller shape for indices than values?
A similar question has already been asked here, but I don't think its solution suits my case.
I wonder why it is not possible to perform a torch.scatter operation where my index tensor is larger than my value tensor. In my case I have duplicate indices, e.g. the following value tensor a and index tensor idx:
import torch
a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
a.scatter(-1, idx, 1) raises:
RuntimeError: Expected index [2, 5] to be smaller than self [2, 4] apart from dimension 1 and to be smaller size than src [2, 4]
Is there another way to achieve this?
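Judging by the error message, idx is checked against two shapes: against self (a) in every dimension except the scatter dimension, and against src in every dimension. With a scalar value, the message reports src as [2, 4], the same shape as a, so the 5-column idx fails the second condition. A minimal sketch, assuming a PyTorch version whose check matches the message above, passes an explicit src with the same shape as idx (torch.ones_like(idx)). Because every duplicate index writes the same value 1, the order in which duplicates are applied should not affect the result.

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])

# src has the same shape as idx (2 x 5): idx is then no larger than src in any
# dimension, and no larger than a outside the scatter dimension (-1).
# ones_like(idx) is int64, matching the dtype of a, as scatter requires.
src = torch.ones_like(idx)
out = a.scatter(-1, idx, src)  # out-of-place; a itself is unchanged
print(out)
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 0]])
```

If the values to write differ between duplicate positions, which one wins is not guaranteed, so this approach is only safe when duplicates carry identical values, as they do here.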
Solution 1:[1]
Not a solution, but a workaround:
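import torch
a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
# Column vector of row numbers, shape (2, 1)
rows = torch.arange(0, a.size(0))[:, None]
n_col = idx.size(1)
# Pair every column index in idx with its row number and set those entries to 1
a[rows.repeat(1, n_col), idx] = 1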
rows.repeat(1, n_col) has the same shape as idx and supplies, for each column index in idx, the index of the row it belongs to.
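Since advanced indexing in PyTorch broadcasts the index tensors against each other, the explicit repeat should not be strictly necessary: a (2, 1) column of row numbers broadcasts against the (2, 5) idx. A short sketch comparing the two forms (the names a_repeat and a_bcast are only for illustration):

```python
import torch

a = torch.tensor([[0, 1, 0, 0],
                  [0, 0, 1, 0]])
idx = torch.tensor([[1, 1, 2, 3, 3],
                    [0, 0, 1, 2, 2]])
rows = torch.arange(a.size(0))[:, None]  # shape (2, 1)

# Explicit repeat, as in the workaround above
a_repeat = a.clone()
a_repeat[rows.repeat(1, idx.size(1)), idx] = 1

# Broadcasting: rows (2, 1) is expanded to idx's shape (2, 5) automatically
a_bcast = a.clone()
a_bcast[rows, idx] = 1

print(torch.equal(a_repeat, a_bcast))  # True
```

Both versions write the same value to duplicate positions, so the repeated indices in idx are harmless here.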
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Christian |
