How do I select values from a PyTorch tensor by index?
I'm struggling to work out how to solve this class problem in PyTorch. The question is: "select, for all i, j, the values x[i,j,k] where ind[i,j] = k into a tensor;
the tensor should have shape (10,50)"
```python
import torch
ind = torch.randint(50, (10, 50))  # integer indices in [0, 50), shape (10, 50)
x = torch.randn(10, 50, 50)
```
Could I do this using torch.scatter or torch.gather?
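To pin down what the question asks for, here is a minimal sketch of the literal double loop, out[i, j] = x[i, j, ind[i, j]], using the same shapes as above. It is only a slow reference to check a vectorized answer against; the fixed seed is an addition so the run is reproducible.

```python
import torch

torch.manual_seed(0)  # added only for reproducibility
ind = torch.randint(50, (10, 50))
x = torch.randn(10, 50, 50)

# Literal reading of the spec: for every (i, j), pick the k given by ind[i, j]
out = torch.empty(10, 50)
for i in range(10):
    for j in range(50):
        out[i, j] = x[i, j, ind[i, j].item()]

print(out.shape)  # torch.Size([10, 50])
```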
Solution 1:[1]
You can use torch.gather; you just have to add a trailing dimension to your indices:
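```python
# ind[:, :, None] has shape (10, 50, 1), so it has as many dims as x
y = torch.gather(x, 2, ind[:, :, None]).squeeze(2)  # shape (10, 50)
assert y[0, 0] == x[0, 0, ind[0, 0]]
```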
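This is because torch.gather requires the index tensor to have the same number of dimensions as the input tensor. ind is 2-D while x is 3-D, so ind[:, :, None] adds a trailing dimension of size 1. Gathering along dim 2 then gives out[i][j][0] = x[i][j][ind[i][j]], and .squeeze(2) removes the size-1 dimension, leaving the requested (10, 50) result.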
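A sketch, assuming PyTorch 1.9 or later (needed for torch.take_along_dim), that prints the intermediate shapes and checks the gather result against two equivalent formulations: plain advanced indexing with broadcast row/column index grids (a swapped-in technique, not part of the answer above) and torch.take_along_dim. It also shows that passing the 2-D ind straight to torch.gather raises a RuntimeError.

```python
import torch

ind = torch.randint(50, (10, 50))
x = torch.randn(10, 50, 50)

idx = ind.unsqueeze(2)             # (10, 50) -> (10, 50, 1), same as ind[:, :, None]
print(idx.shape)                   # torch.Size([10, 50, 1])

g = torch.gather(x, 2, idx)        # g[i][j][0] = x[i][j][idx[i][j][0]]
print(g.shape)                     # torch.Size([10, 50, 1])
y = g.squeeze(2)                   # drop the size-1 dim -> (10, 50)

# Alternative 1: advanced indexing; rows (10, 1) and cols (1, 50) broadcast with ind (10, 50)
rows = torch.arange(10)[:, None]
cols = torch.arange(50)[None, :]
y_adv = x[rows, cols, ind]

# Alternative 2: take_along_dim behaves like gather here (PyTorch >= 1.9)
y_tad = torch.take_along_dim(x, idx, dim=2).squeeze(2)

assert torch.equal(y, y_adv) and torch.equal(y, y_tad)

# The 2-D ind alone is rejected: the index needs as many dims as x
raised = False
try:
    torch.gather(x, 2, ind)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

Advanced indexing avoids the unsqueeze/squeeze round trip, while gather and take_along_dim extend directly to picking several k per (i, j): an index of shape (10, 50, m) returns a (10, 50, m) result.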
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jhso |
