Find index where a sub-tensor does not equal a given tensor in PyTorch
I have a tensor, for example,
a = [[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]
which has the shape (4,4).
Given a specific sub-tensor
[-1,-1,-1,-1]
how can I use PyTorch to find the indices of the rows where it does not appear? The expected output is
[0,2]
Solution 1:[1]
You can compare each row of the tensor element-wise against the target, use torch.any() to flag the rows with at least one mismatch, and then use .nonzero() and .flatten() to turn that mask into indices:
torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten()
For example,
import torch
a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])
result = torch.any(a != torch.Tensor([-1, -1, -1, -1]), axis=1).nonzero().flatten()
print(result)
outputs:
tensor([0, 2])
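The same idea can be wrapped in a small helper. This is only a sketch: the name rows_not_equal is made up for illustration, and it uses integer tensors built with torch.tensor and the dim keyword rather than torch.Tensor and axis.

```python
import torch

def rows_not_equal(a: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return indices of rows in `a` that differ from `target` in at least one element."""
    # Broadcasting compares `target` against every row of `a`.
    mismatch = (a != target).any(dim=1)
    # nonzero() yields a column of row indices; flatten() makes it 1-D.
    return mismatch.nonzero().flatten()

a = torch.tensor([[15, 30, 0, 2], [-1, -1, -1, -1], [10, 20, 40, 60], [-1, -1, -1, -1]])
target = torch.tensor([-1, -1, -1, -1])
indices = rows_not_equal(a, target)
print(indices.tolist())  # [0, 2]
```

Because the mask is reduced per row, a row that only partly matches the target (say [-1, 5, -1, -1]) would still count as not equal.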
Solution 2:[2]
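You can also use torch.where or torch.nonzero. Either one returns a row index for every mismatching element, so .unique() is used to drop the repeats: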
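import torch
a = torch.Tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])
b = torch.Tensor([-1,-1,-1,-1])
# Option 1: torch.where with only a condition returns one index tensor per dimension
result = torch.where(a != b)[0].unique()
# Option 2: the equivalent call with torch.nonzero (either line alone is enough)
result = torch.nonzero(a != b, as_tuple=True)[0].unique()
print(result)  # tensor([0, 2])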
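To see why .unique() matters in this approach, it may help to look at the intermediate result. A minimal sketch, where raw_rows is just an illustrative name and integer tensors are used for readability:

```python
import torch

a = torch.tensor([[15, 30, 0, 2], [-1, -1, -1, -1], [10, 20, 40, 60], [-1, -1, -1, -1]])
b = torch.tensor([-1, -1, -1, -1])

# One row index per mismatching element, so rows 0 and 2 each appear four times.
raw_rows = torch.where(a != b)[0]
print(raw_rows.tolist())           # [0, 0, 0, 0, 2, 2, 2, 2]

# unique() collapses the repeats into one entry per row.
print(raw_rows.unique().tolist())  # [0, 2]
```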
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | BrokenBenchmark |
| Solution 2 | Phoenix |
