Select on the second dimension of a 3D PyTorch tensor with an array of indexes
I am fairly new to NumPy and PyTorch, and I am struggling to understand what seem to me the most basic operations.
For instance, given this tensor:
A = tensor([[[6, 3, 8, 3],
             [1, 0, 9, 9]],
            [[4, 9, 4, 1],
             [8, 1, 3, 5]],
            [[9, 7, 5, 6],
             [3, 7, 8, 1]]])
And this other tensor:
B = tensor([1, 0, 1])
I would like to use B to index the second dimension of A, taking row B[i] from A[i], so that I get a 3 by 4 tensor that looks like this:
[[1, 0, 9, 9],
 [4, 9, 4, 1],
 [3, 7, 8, 1]]
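To pin down exactly what is being asked, here is a slow but explicit reference version. It is a minimal sketch assuming PyTorch is installed and importable as `torch`; the name `expected` is only illustrative. For each position i along the first dimension it picks row B[i] of A[i] and stacks the results.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# Reference loop: out[i] = A[i, B[i]], i.e. one row per matrix along dim 0.
expected = torch.stack([A[i, B[i]] for i in range(len(B))])
print(expected.shape)  # torch.Size([3, 4])
print(expected)
```

A Python loop like this is fine for checking results but slow on large tensors, which is why the vectorized indexing below is preferable.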
Thanks!
Solution 1:[1]
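OK, my mistake was to assume that this:
A[:, B]
is equal to this:
A[[0, 1, 2], B]
They are not the same. With a slice on the first dimension, every matrix A[i] is indexed with all of B, which gives a 3 by 3 by 4 tensor. With two index arrays of the same length, the indexes are paired element-wise, giving A[i, B[i]] for each i, which is the 3 by 4 result I wanted.
More generally, the solution I wanted is:
A[range(B.shape[0]), B]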
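A small script that contrasts the two indexings discussed above. It uses `torch.arange` in place of `range` for the row indexes, which is a choice of this sketch rather than the answerer's code, and the names `wrong`, `right` and `rows` are illustrative only.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# Slice on dim 0 + index tensor on dim 1: each A[i] is indexed with the
# whole of B, so wrong[i, j] = A[i, B[j]] and the shape is (3, 3, 4).
wrong = A[:, B]

# Two index tensors of equal length are paired element-wise:
# right[i] = A[i, B[i]], so the shape is (3, 4).
rows = torch.arange(B.shape[0])
right = A[rows, B]

print(wrong.shape)  # torch.Size([3, 3, 4])
print(right)
```

The wanted rows are still hidden inside `wrong`: they sit on its "diagonal", since `wrong[i, i]` equals `A[i, B[i]]`.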
Solution 2:[2]
Alternatively, you can use torch.gather:
>>> indexer = B.view(-1, 1, 1).expand(-1, -1, 4)
>>> indexer
tensor([[[1, 1, 1, 1]],
        [[0, 0, 0, 0]],
        [[1, 1, 1, 1]]])
>>> A.gather(1, indexer).view(len(B), -1)
tensor([[1, 0, 9, 9],
        [4, 9, 4, 1],
        [3, 7, 8, 1]])
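The `gather` call above hard-codes the size of the last dimension (4). A sketch that reads the sizes from `A` instead could look like the following. The helper name `select_rows` is hypothetical, and the `.long()` cast is an added precaution because `gather` has traditionally required an int64 index tensor.

```python
import torch


def select_rows(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return out[i] = A[i, B[i]] for A of shape (N, R, C) and B of shape (N,)."""
    n, _, c = A.shape
    # gather needs an index with as many dimensions as A:
    # (N,) -> (N, 1, 1) -> (N, 1, C), repeating B[i] once per column.
    index = B.long().view(-1, 1, 1).expand(n, 1, c)
    # Gathering along dim 1 yields shape (N, 1, C); drop the singleton dim.
    return A.gather(1, index).squeeze(1)


A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

out = select_rows(A, B)
print(out)
```

Using `squeeze(1)` instead of `view(len(B), -1)` removes only the singleton dimension that `gather` leaves behind, which states the intent a little more directly; both give the same (N, C) result here.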
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Alex Pi |
| Solution 2 | Ivan |
