Replace every element of each channel with 100 if the channel has any value greater than 0 (and with 1 otherwise), without looping
I have a batch tensor of size (4, 100, 56, 56) in which some channels contain non-zero values and others are all zeros. I want every element of a channel that has any value greater than 0 to become 100, and every element of an all-zero channel to become 1. Any idea how to achieve this without looping?
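import torch

t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07  # only channel 5 contains non-zero values

new_t = torch.ones((4, 100, 56, 56))  # all-zero channels stay at 1
for b in range(t.size(0)):
    for c in range(t.size(1)):
        # a channel with any value > 0 becomes 100
        if t[b, c, :, :].max() > 0:
            new_t[b, c, :, :] = 100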
My code above is inefficient for large batches and many channels, and it creates memory overhead because of the extra tensor new_t. Is there a way to use view() or a similar function to achieve this?
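As a rough sketch of the view() idea (assuming a contiguous float tensor like the example), the two spatial dimensions can be merged with view(), reduced with amax() to one flag per channel, and the per-channel fill values broadcast back with expand(), which returns a view instead of allocating a full 4 x 100 x 56 x 56 tensor. The names flat, has_pos and per_channel are illustrative only.

```python
import torch

# Example input from the question: only channel 5 has non-zero values.
t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07

# view() merges the two spatial dims without copying (t is contiguous),
# so each channel becomes one row of 56 * 56 values.
flat = t.view(t.size(0), t.size(1), -1)             # shape (4, 100, 3136)

# One boolean per (batch, channel): True if the channel's max is > 0,
# i.e. the channel has at least one value greater than 0.
has_pos = flat.amax(dim=-1) > 0                     # shape (4, 100)

# Per-channel fill value: 100 where has_pos is True, 1 otherwise.
per_channel = torch.where(has_pos, torch.tensor(100.0), torch.tensor(1.0))

# expand_as() returns a broadcast view; no 4 x 100 x 56 x 56 buffer is allocated.
new_t = per_channel[:, :, None, None].expand_as(t)  # shape (4, 100, 56, 56)
```

Because new_t is an expanded view (its spatial strides are 0), it should be treated as read-only; calling new_t.contiguous() gives a writable, fully materialized tensor, at the cost of the full-size allocation.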
Solution 1:[1]
You can perform the following:
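mask = torch.any(t.flatten(2, 3) > 0., dim=2)  # shape (4, 100): True if the channel has any value > 0
t[mask] = 100.  # every element of those channels becomes 100
t[~mask] = 1.   # every element of the all-zero channels becomes 1
# note: the alternative t[mask] *= 100. multiplies the existing values by 100 instead of setting them to 100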
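A minimal sketch for checking a vectorized version against the loop from the question, assuming the same example tensor. fill_channels_ is a hypothetical helper (not part of the answer) that overwrites t in place with fill_() and a broadcast masked_fill_(), so no second full-size tensor is kept; the last lines apply the answer's boolean-mask indexing, completed with the ~mask case.

```python
import torch

def loop_reference(t):
    # The question's loop, kept as a reference (allocates a second tensor).
    new_t = torch.ones_like(t)
    for b in range(t.size(0)):
        for c in range(t.size(1)):
            if t[b, c].max() > 0:
                new_t[b, c] = 100
    return new_t

def fill_channels_(t, pos_value=100.0, zero_value=1.0):
    # Hypothetical in-place, loop-free helper; parameter names are illustrative.
    # Per-channel max with keepdim=True -> shape (B, C, 1, 1), broadcastable over H, W.
    has_pos = t.amax(dim=(2, 3), keepdim=True) > 0
    t.fill_(zero_value)                  # every channel starts at 1
    t.masked_fill_(has_pos, pos_value)   # channels that had a value > 0 become 100
    return t

t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07

expected = loop_reference(t)
result = fill_channels_(t.clone())       # clone so t itself is left untouched

# The answer's boolean-mask indexing, completed so all-zero channels become 1.
t2 = t.clone()
mask = torch.any(t2.flatten(2, 3) > 0., dim=2)   # shape (4, 100)
t2[mask] = 100.
t2[~mask] = 1.
```

The per-channel maximum is computed before fill_() runs, and has_pos is a separate tensor, so overwriting t does not change the mask.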
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | aretor |
