How do I select values from a PyTorch tensor by index?

I was struggling to figure out how to approach this class problem using PyTorch. The question is: "Select, for all i,j, the values x[i,j,k] where ind[i,j] = k. The resulting tensor should have shape (10,50)." The inputs are:

import torch
# ind[i, j] is an index into the last axis of x (values 0..49)
ind = torch.randint(50, (10, 50))
x = torch.randn(10, 50, 50)

Could I do this using torch.scatter or .gather?
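To pin down what "select for all i,j" means, here is a minimal reference sketch (not part of the original question) that computes y[i, j] = x[i, j, ind[i, j]] with an explicit double loop. The function name select_loop is hypothetical; the loop is slow, but it is handy for checking any vectorized answer.

```python
import torch

# Same shapes as in the question: ind holds indices into the last axis of x
ind = torch.randint(50, (10, 50))
x = torch.randn(10, 50, 50)

def select_loop(x, ind):
    """Hypothetical reference helper: y[i, j] = x[i, j, ind[i, j]] for every i, j."""
    n, m = ind.shape
    y = torch.empty(n, m, dtype=x.dtype)
    for i in range(n):
        for j in range(m):
            # ind[i, j] picks one position k along the last axis of x
            y[i, j] = x[i, j, int(ind[i, j])]
    return y

y_ref = select_loop(x, ind)
print(y_ref.shape)  # torch.Size([10, 50])
```
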



Solution 1:[1]

You can use torch.gather; you just have to add a trailing dimension to your indices:

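# Gather along dim 2 with ind reshaped to (10, 50, 1), then drop the size-1 axis
y = torch.gather(x, 2, ind[:, :, None]).squeeze(2)
# y has shape (10, 50), so compare a single element rather than the whole row y[0]
assert y[0, 0] == x[0, 0, ind[0, 0]]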

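This is because torch.gather requires the index tensor to have the same number of dimensions as the input tensor. ind has shape (10,50) while x has shape (10,50,50), so ind[:,:,None] adds a trailing axis of size 1. gather then returns a (10,50,1) tensor, and .squeeze(2) removes that axis to give the requested (10,50) result.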
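A minimal sketch, assuming the same random x and ind as in the question, that spells out the shape at each step of the gather approach and checks every element against the definition y[i, j] = x[i, j, ind[i, j]]. The torch.take_along_dim and advanced-indexing versions at the end are alternatives not mentioned in the original answer; torch.take_along_dim needs PyTorch 1.9 or newer.

```python
import torch

ind = torch.randint(50, (10, 50))
x = torch.randn(10, 50, 50)

# gather needs an index with the same number of dims as x, so add a trailing axis
idx = ind[:, :, None]             # shape (10, 50, 1)
picked = torch.gather(x, 2, idx)  # shape (10, 50, 1): picked[i, j, 0] = x[i, j, ind[i, j]]
y = picked.squeeze(2)             # shape (10, 50)

# Element-wise check against the definition for every (i, j)
for i in range(10):
    for j in range(50):
        assert y[i, j] == x[i, j, int(ind[i, j])]

# Alternative 1: take_along_dim follows the same index-shape rule as gather
y_tad = torch.take_along_dim(x, idx, dim=2).squeeze(2)

# Alternative 2: advanced indexing with row/column index tensors that broadcast to ind's shape
i_idx = torch.arange(10)[:, None]  # shape (10, 1)
j_idx = torch.arange(50)[None, :]  # shape (1, 50)
y_adv = x[i_idx, j_idx, ind]       # shape (10, 50)

assert torch.equal(y, y_tad) and torch.equal(y, y_adv)
print(y.shape)  # torch.Size([10, 50])
```

gather and take_along_dim read along one chosen dimension and need the extra size-1 axis; the advanced-indexing version instead builds explicit row and column indices, which broadcast against the (10,50) ind and produce the (10,50) result directly without a squeeze.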

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: jhso