Replace every element in each channel if the channel has any value greater than 0, without looping

I have a batch tensor of size (4, 100, 56, 56) in which some channels contain nonzero values and others are all zeros. I want every element of a channel that has any value greater than 0 to become 100, and every element of an all-zero channel to become 1. Any idea how to achieve this without looping?

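import torch

t = torch.zeros((4, 100, 56, 56))
t[:, 5, 15:20, 15:20] = 0.07  # only channel 5 has nonzero values
new_t = torch.ones((4, 100, 56, 56))  # all-zero channels keep the value 1

# Set every channel whose maximum is above 0 to 100
for b in range(t.size(0)):
    for c in range(t.size(1)):
        if t[b, c, :, :].max() > 0:
            new_t[b, c, :, :] = 100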

The loop above is slow for large batches and many channels, and it adds memory overhead because of the separate new_t tensor. Is there a way to use view() or similar functions to achieve this?



Solution 1:[1]

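Build a (batch, channel) boolean mask that is True for every channel containing a value greater than 0, then use it to index t in place. This avoids both the Python loop and the extra new_t tensor: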

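mask = torch.any(t.flatten(2, 3) > 0., dim=2)  # shape (4, 100): True if the channel has a value > 0
t[mask] = 100.  # channels with any positive value become all 100
t[~mask] = 1.   # all-zero channels become all 1
# note: t[mask] *= 100. would scale the values (0.07 -> 7.0) rather than set them to 100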
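For comparison, here is a minimal sketch, assuming a recent PyTorch that provides `torch.amax` with a tuple of dims. It wraps three approaches in hypothetical helpers: `loop_version` (the question's loop), `masked_version` (the mask approach, extended to also set all-zero channels to 1, applied to a copy), and `expanded_version`, which targets the `view()` part of the question. The last one computes a single value per channel and broadcasts it with `expand_as`, which returns a view, so no full (4, 100, 56, 56) buffer is allocated. The channel test uses "max > 0", which is the same condition as "any value > 0".

```python
import torch


def loop_version(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: the question's loop, with the missing colon fixed."""
    new_t = torch.ones_like(t)  # all-zero channels keep 1
    for b in range(t.size(0)):
        for c in range(t.size(1)):
            if t[b, c].max() > 0:
                new_t[b, c] = 100
    return new_t


def masked_version(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: boolean-mask indexing on a clone, so `t` is not modified."""
    out = t.clone()
    mask = torch.any(out.flatten(2, 3) > 0.0, dim=2)  # shape (B, C)
    out[mask] = 100.0  # channels with any positive value
    out[~mask] = 1.0   # all-zero channels
    return out


def expanded_version(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build a (B, C, 1, 1) tensor and broadcast it as a view.

    expand_as does not copy data; every spatial position of a channel shares
    one stored value. Treat the result as read-only and call .contiguous()
    if a writable full-size copy is needed.
    """
    keep = t.amax(dim=(2, 3), keepdim=True) > 0  # shape (B, C, 1, 1)
    small = torch.where(keep, torch.tensor(100.0), torch.tensor(1.0))
    return small.expand_as(t)


if __name__ == "__main__":
    t = torch.zeros((4, 100, 56, 56))
    t[:, 5, 15:20, 15:20] = 0.07
    ref = loop_version(t)
    print(torch.equal(masked_version(t), ref), torch.equal(expanded_version(t), ref))
```

The in-place mask version fits when the result must later be written to; the expanded view is cheaper when the result is only read, for example when it is multiplied into another tensor.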

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: aretor