Find the indices where a sub-tensor does not equal a given tensor in PyTorch

I have a tensor, for example,

a = [[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]]

which has the shape (4,4). Using PyTorch, how can I find the indices of the rows that are not equal to the following sub-tensor?

[-1,-1,-1,-1]

For this example, the expected output is

[0,2]


Solution 1:[1]

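Compare every row of the tensor with the target using != (broadcasting handles the shapes), reduce each row with torch.any() along dim=1 so a row is True if at least one element differs, and then use .nonzero() and .flatten() to turn the result into a 1-D tensor of indices: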

torch.any(a != torch.tensor([-1, -1, -1, -1]), dim=1).nonzero().flatten()

For example,

import torch

a = torch.tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])
# True for every row that differs from [-1, -1, -1, -1] in at least one position
result = torch.any(a != torch.tensor([-1, -1, -1, -1]), dim=1).nonzero().flatten()
print(result)

outputs:

tensor([0, 2])
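To make each step of the one-liner visible, the sketch below splits it into intermediate tensors. It assumes the same example data, built with torch.tensor (integer dtype); the names mismatch, row_differs and indices are only illustrative.

```python
import torch

# Example data from the question (integer tensors)
a = torch.tensor([[15, 30, 0, 2], [-1, -1, -1, -1], [10, 20, 40, 60], [-1, -1, -1, -1]])
target = torch.tensor([-1, -1, -1, -1])

# Broadcasting compares every row of `a` (shape (4, 4)) with `target` (shape (4,)),
# giving a (4, 4) boolean mask that is True wherever an element differs
mismatch = a != target

# A row differs from `target` if at least one of its elements differs:
# one boolean per row (True, False, True, False for this data)
row_differs = torch.any(mismatch, dim=1)

# nonzero() returns an (N, 1) tensor of positions; flatten() reshapes it to (N,)
indices = row_differs.nonzero().flatten()
print(indices)  # tensor([0, 2])
```

Because the reduction happens per row before nonzero() is called, each qualifying row index appears exactly once, so no deduplication step is needed.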

Solution 2:[2]

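You can also use torch.where() or torch.nonzero(). Applied to a 2-D mask, both return the row and column indices of every mismatching element, so take the row indices ([0]) and remove duplicates with .unique():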

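import torch

a = torch.tensor([[15,30,0,2], [-1,-1,-1,-1], [10, 20, 40, 60], [-1,-1,-1,-1]])

b = torch.tensor([-1,-1,-1,-1])

# Option 1: torch.where returns (row indices, column indices) of the mismatches
result = torch.where(a != b)[0].unique()

# Option 2 (equivalent): torch.nonzero with as_tuple=True returns the same tuple
result = torch.nonzero(a != b, as_tuple=True)[0].unique()

print(result)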
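The sketch below shows why .unique() is needed with torch.where(): the row index is repeated once per differing element. It also shows an equivalent per-row formulation, "not all elements equal", which by De Morgan's law matches torch.any(a != b, dim=1). The helper rows_not_equal is hypothetical and only for illustration.

```python
import torch

a = torch.tensor([[15, 30, 0, 2], [-1, -1, -1, -1], [10, 20, 40, 60], [-1, -1, -1, -1]])
b = torch.tensor([-1, -1, -1, -1])

# For a 2-D mask, torch.where(mask) returns (row_indices, col_indices),
# with one entry per True element, so rows 0 and 2 each appear four times
rows, cols = torch.where(a != b)

# unique() removes the repeats (and sorts the result by default)
print(rows.unique())  # tensor([0, 2])


def rows_not_equal(x: torch.Tensor, row: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: indices of rows of `x` that differ from `row` somewhere."""
    # (x == row).all(dim=1) is True for rows identical to `row`; ~ negates it
    return (~(x == row).all(dim=1)).nonzero().flatten()


print(rows_not_equal(a, b))  # tensor([0, 2])
```

The per-row reduction avoids building the full list of mismatching elements first, which can matter when the rows are long.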

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: BrokenBenchmark
[2] Solution 2: Phoenix