Select along the second dimension of a 3D PyTorch tensor with an array of indexes

I am fairly new to numpy and torch, and I am struggling with what seem to me to be the most basic operations.

For instance, given this tensor:

A = tensor([[[6, 3, 8, 3],
         [1, 0, 9, 9]],

        [[4, 9, 4, 1],
         [8, 1, 3, 5]],

        [[9, 7, 5, 6],
         [3, 7, 8, 1]]])

And this other tensor:

B = tensor([1, 0, 1])

I would like to use B as indexes for A so that I get a 3 by 4 tensor that looks like this:

[[1, 0, 9, 9],
 [4, 9, 4, 1],         
 [3, 7, 8, 1]]

Thanks!



Solution 1:[1]

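OK, my mistake was assuming that this:

A[:, B]

is equivalent to this:

A[[0, 1, 2], B]

It is not. A[:, B] applies every index in B to every entry of the first dimension and returns a 3 by 3 by 4 tensor, whereas A[[0, 1, 2], B] pairs entry i of the first dimension with B[i] and returns the 3 by 4 tensor I was after.

More generally, the solution I wanted is:

A[range(B.shape[0]), B]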
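
For anyone who wants to check this, here is a minimal sketch using the tensors from the question. The variable names all_rows, rows and picked are only illustrative, and torch.arange is used in place of range as an assumed equivalent way to build the row indexes.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# The slice keeps all of dim 0, and every index in B is applied to each entry.
all_rows = A[:, B]
print(all_rows.shape)

# Pairing entry i of dim 0 with B[i] selects exactly one row per entry.
rows = torch.arange(B.shape[0])
picked = A[rows, B]
print(picked)
```

The first print shows a 3 by 3 by 4 shape, and the second prints the 3 by 4 tensor from the question.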

Solution 2:[2]

Alternatively, you can use torch.gather:

>>> indexer = B.view(-1, 1, 1).expand(-1, -1, 4)
>>> indexer
tensor([[[1, 1, 1, 1]],

        [[0, 0, 0, 0]],

        [[1, 1, 1, 1]]])

>>> A.gather(1, indexer).view(len(B), -1)
tensor([[1, 0, 9, 9],
        [4, 9, 4, 1],
        [3, 7, 8, 1]])
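
As a self-contained variation on the same idea, the sketch below reads the size of the last dimension from A instead of hard-coding 4, and drops the singleton dimension with squeeze(1) rather than view. Both changes are my own assumptions about what might generalize better, not part of the original answer.

```python
import torch

A = torch.tensor([[[6, 3, 8, 3],
                   [1, 0, 9, 9]],
                  [[4, 9, 4, 1],
                   [8, 1, 3, 5]],
                  [[9, 7, 5, 6],
                   [3, 7, 8, 1]]])
B = torch.tensor([1, 0, 1])

# gather needs an index with as many dimensions as A. Here it has
# shape (3, 1, 4): for each entry of dim 0, B's index repeated along dim 2.
indexer = B.view(-1, 1, 1).expand(-1, 1, A.shape[-1])

# Gather along dim 1, then remove the leftover singleton dimension.
result = A.gather(1, indexer).squeeze(1)
print(result)
```

The index tensor passed to gather must be of integer type int64, which is what torch.tensor produces by default for Python integers.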

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution      Source
Solution 1    Alex Pi
Solution 2    Ivan