Extract the indices of image patches that contain white color

I have a batch of images with shape x.shape = [32, 256, 96, 3]. A sample image looks like this:

[image: sample image with a white line in the middle]

I have extracted patches of size patch.shape = [16, 16, 3] and flattened each one into the last dimension, with all patches stacked along dim=1. This gives a final size of x_patched.shape = [32, 96, 768], where 96 = 16 (num_patch_height) * 6 (num_patch_width) and 768 = 16 * 16 * 3.
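The patch layout described above can be sketched as follows. This is a minimal illustration, not the asker's code: it assumes a channels-last batch [B, H, W, C] as stated in the text, a random placeholder tensor for x, and a hypothetical helper named patchify. Patches are numbered row by row, so patch index p corresponds to grid position (p // 6, p % 6).

```python
import torch

def patchify(x: torch.Tensor, patch: int = 16) -> torch.Tensor:
    """Split a channels-last batch [B, H, W, C] into flattened
    patches of shape [B, num_patches, patch * patch * C]."""
    B, H, W, C = x.shape
    nh, nw = H // patch, W // patch  # patches along height and width
    # Split H and W into (grid position, offset inside the patch).
    x = x.reshape(B, nh, patch, nw, patch, C)
    # Put the two grid axes next to each other: [B, nh, nw, patch, patch, C].
    x = x.permute(0, 1, 3, 2, 4, 5)
    # Merge the grid axes into one patch index and flatten each patch.
    return x.reshape(B, nh * nw, patch * patch * C)

x = torch.rand(32, 256, 96, 3)  # placeholder batch (assumption)
x_patched = patchify(x)
print(x_patched.shape)  # torch.Size([32, 96, 768])
```

With this layout, x_patched[b, 7] holds the pixels of x[b, 16:32, 16:32, :], because 7 = 1 * 6 + 1.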

In the image shown above, the white line in the middle represents NaN values. Now I want to extract only those patches that contain NaN values, preferably as patch indices into x_patched (for example, of the 96 patches, indices 3, 4, 9, 10, ... are the ones that contain NaN values).

My Code:

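import torch

x = torch.rand(32, 3, 256, 96) * 255  # placeholder batch (assumption); use the real data here
B, C, H, W = x.shape  # 32, 3, 256, 96; values of x lie between 0 and 255
num_patches_w, num_patches_h = (W // 16), (H // 16)  # 6, 16
# Note: view() only reinterprets memory order; it does not cut 16x16 spatial tiles.
x_patched = x.view(-1, 768, num_patches_h, num_patches_w)  # 32, 768, 16, 6
x_patched = x_patched.flatten(2).transpose(1, 2)  # 32, 96, 768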
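One possible way to get the patch indices asked for above is to reduce a NaN mask over the last dimension. The sketch below assumes x_patched already holds genuine spatial patches in the shape [B, num_patches, patch_dim]. The tensor here is a small hypothetical stand-in with NaNs planted at known positions, and the names nan_mask and nan_idx are illustrative only.

```python
import torch

# Hypothetical stand-in for x_patched: [B, num_patches, patch_dim].
x_patched = torch.rand(2, 96, 768)
x_patched[0, 3, 10] = float("nan")  # a single NaN inside patch 3 of image 0
x_patched[0, 4, :] = float("nan")   # patch 4 of image 0 is entirely NaN
x_patched[1, 9, 0] = float("nan")   # a single NaN inside patch 9 of image 1

# True where a patch contains at least one NaN; shape [B, num_patches].
nan_mask = torch.isnan(x_patched).any(dim=-1)

# One list of patch indices per image in the batch.
nan_idx = [torch.nonzero(m, as_tuple=True)[0].tolist() for m in nan_mask]
print(nan_idx)  # [[3, 4], [9]]
```

The indices are kept per image because the number of NaN patches can differ between images, so they do not generally fit in one rectangular tensor. To require that a patch is entirely NaN, .all(dim=-1) can replace .any(dim=-1).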


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow