How to get unique elements and the indices where they first appear in a PyTorch tensor?
Assume a 2×X (always 2 rows) PyTorch tensor:

```
A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
```
`torch.unique(A, dim=1)` will return:

```
tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])
```
But I also need, for each unique column, the index where it first appears in the original input. In this case, the indices should be:

```
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6)
```
It's tricky for me because the second row of tensor A is not necessarily sorted within a group of equal first-row values:

```
A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
                            ^    ^
```
Is there a simple and efficient method to get the desired indices?
P.S. In case it helps: the first row of the tensor is always in ascending order.
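For reference, the desired result can be pinned down with a plain Python loop. This is not part of the original question and is slow for large tensors because it iterates column by column in Python, but it makes "first appearance" precise: walk the columns left to right and record the index of each column the first time it is seen. The helper name `first_occurrence_indices_loop` is hypothetical.

```python
import torch

def first_occurrence_indices_loop(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: index of the first occurrence of each
    distinct column of A, in order of first appearance."""
    seen = set()
    first = []
    for j in range(A.size(1)):
        key = tuple(A[:, j].tolist())  # column as a hashable tuple
        if key not in seen:
            seen.add(key)
            first.append(j)
    return torch.tensor(first, dtype=torch.long)

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
print(first_occurrence_indices_loop(A))  # tensor([0, 1, 2, 3, 4, 6])
```

This returns the indices in order of first appearance, matching the expected output in the question; a vectorized approach is still preferable for efficiency.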
Solution 1:[1]
One possible way to obtain these indices:
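```python
import torch

unique, inverse, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
# A stable sort keeps the original column order within each group of equal columns
_, ind_sorted = torch.sort(inverse, stable=True)
# Start of each group in ind_sorted: exclusive cumulative sum of counts
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((counts.new_zeros(1), cum_sum[:-1]))
first_indices = ind_sorted[cum_sum]
```

For the tensor `A` from the question:

```python
print(first_indices)
# tensor([0, 1, 2, 4, 3, 6])
```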
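To see why this works, it may help to print the intermediate tensors for the example `A` (a sketch; the names `inverse` and `starts` are my own). The stable sort lists the original column positions grouped by unique id, in ascending order within each group, and the exclusive cumulative sum of `counts` points at the first slot of each group, so indexing `ind_sorted` there picks the smallest position of each unique column.

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

unique, inverse, counts = torch.unique(
    A, dim=1, sorted=True, return_inverse=True, return_counts=True)
print(inverse)     # tensor([0, 1, 2, 4, 3, 4, 5, 5, 5])  unique id of each column
print(counts)      # tensor([1, 1, 1, 1, 2, 3])           size of each group

# Stable sort: positions grouped by unique id, ascending within a group
_, ind_sorted = torch.sort(inverse, stable=True)
print(ind_sorted)  # tensor([0, 1, 2, 4, 3, 5, 6, 7, 8])

# Exclusive prefix sum of counts = where each group starts in ind_sorted
starts = torch.cat((counts.new_zeros(1), counts.cumsum(0)[:-1]))
print(starts)      # tensor([0, 1, 2, 3, 4, 6])

print(ind_sorted[starts])  # tensor([0, 1, 2, 4, 3, 6])
```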
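Note that `unique` in this case is equal to:

```
tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])
```

The indices are aligned with the columns of `unique`, which are sorted lexicographically; that is why 4 and 3 appear swapped compared with the expected output in the question. Sorting the result (for example with `torch.sort`) gives `tensor([0, 1, 2, 3, 4, 6])`.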
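A different technique, not part of the answer above, avoids the explicit sort of the inverse indices: scatter every column position into the slot of its unique id and keep the minimum with `reduce="amin"`. This assumes PyTorch 1.12 or newer, where `Tensor.scatter_reduce` is available. The helper name `first_occurrence_indices` is hypothetical; treat this as a sketch rather than a drop-in replacement.

```python
import torch

def first_occurrence_indices(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: for each column of torch.unique(A, dim=1),
    the smallest column index of A where that column occurs."""
    unique, inverse = torch.unique(A, dim=1, return_inverse=True)
    positions = torch.arange(A.size(1), device=A.device)
    # Start every slot above any valid index, then keep the minimum position
    first = torch.full((unique.size(1),), A.size(1), dtype=torch.long, device=A.device)
    return first.scatter_reduce(0, inverse, positions, reduce="amin")

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
first = first_occurrence_indices(A)
print(first)                     # tensor([0, 1, 2, 4, 3, 6])  aligned with unique's columns
print(torch.sort(first).values)  # tensor([0, 1, 2, 3, 4, 6])  order of first appearance
```

Like the sort-based version, this returns indices aligned with the columns of `unique`; it replaces the stable sort with a single scatter pass, although `torch.unique(..., dim=1)` itself still sorts the columns internally.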
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | draw |
