How do I get one patch as a variable with PyTorch's unfold method?

I divided an image of size 224 x 224 into patches and was able to change the pixel values of one patch with the following code snippet. Each patch is 16x16, so I have 196 patches altogether (a 14 x 14 grid).

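import torch.nn.functional as nnf

# img has shape (1, 3, 224, 224)
uf = nnf.unfold(img, kernel_size=16, stride=16, padding=0)  # shape (1, 768, 196)
uf[..., 130] = 0  # zero the column that holds patch 130
f = nnf.fold(uf, img.shape[-2:], kernel_size=16, stride=16, padding=0)  # back to (1, 3, 224, 224)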

uf[..., 130] = 0 changes all pixel values of patch 130 to 0. The code above splits the image into patches, edits the pixels of the specified patch, and combines the patches back into the complete image (the change is visible when the result is visualized).
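A minimal sketch of the pipeline described above, assuming a random tensor as a stand-in for the real image. It checks that fold exactly undoes unfold when stride equals kernel_size (the patches do not overlap), and that zeroing column 130 changes exactly one 16 x 16 block of the image.

```python
import torch
import torch.nn.functional as nnf

# Random stand-in for the 224 x 224 RGB image (batch of 1)
img = torch.rand(1, 3, 224, 224)

uf = nnf.unfold(img, kernel_size=16, stride=16, padding=0)  # (1, 768, 196)

# Without any edit, fold inverts unfold here because the patches do not overlap
restored = nnf.fold(uf, img.shape[-2:], kernel_size=16, stride=16, padding=0)
print(torch.allclose(restored, img))  # True

# Zero patch 130, then fold: only one 16 x 16 block of the image changes
uf[..., 130] = 0
f = nnf.fold(uf, img.shape[-2:], kernel_size=16, stride=16, padding=0)
changed = (f != img).any(dim=1)[0]  # (224, 224) map of pixels that changed
print(int(changed.sum()))  # number of changed pixel positions
```

Patch 130 sits in grid row 9, grid column 4, so the changed pixels are rows 144 to 159 and columns 64 to 79.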

I need the patch as a variable. I expected patch = uf[..., 130] to have shape 3x16x16, matching the shape of patch 130, but got torch.Size([1, 768]) instead.

How can I store one patch in a variable, get it in the correct shape, and put it back so that fold recombines the image?



Solution 1:[1]

unfold returns a tensor of shape (N, C * 16 * 16, L). For each image in the batch this is a two-dimensional matrix whose columns index the patches (L = 196) and whose rows index the pixel positions inside a patch across all channels (indeed 768 == 3 * 16 * 16).
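A small sketch to make the row layout visible, assuming a synthetic image in which every pixel of channel c has the value c. Looking at one column then shows that the 768 rows are grouped channel by channel: the first 256 rows come from channel 0, the next 256 from channel 1, and the last 256 from channel 2.

```python
import torch
import torch.nn.functional as nnf

# Synthetic test image: every pixel of channel c holds the constant value c (0, 1, 2)
img = torch.arange(3.0).view(1, 3, 1, 1).repeat(1, 1, 224, 224)

uf = nnf.unfold(img, kernel_size=16, stride=16, padding=0)
# 768 rows = 3 channels * 16 * 16 pixels; 196 columns = 14 * 14 patches
print(uf.shape)  # torch.Size([1, 768, 196])

col = uf[0, :, 130]  # all 768 values of patch 130
# Rows 0-255 come from channel 0, 256-511 from channel 1, 512-767 from channel 2
for c in range(3):
    print(c, torch.all(col[c * 256 : (c + 1) * 256] == c).item())  # True each time
```

This channel-first ordering is why a single column can be reshaped to (3, 16, 16).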

To get a patch back in its original shape, reshape its column: patch = uf[..., 130].reshape(1, 3, 16, 16). You can also select the desired patch directly in the original image by converting the linear index 130 into a (row, column) pair on the 14 x 14 patch grid:

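from math import ceil

h = ceil(224 / 16)  # 14 patches per row
i, j = 130 // h, 130 % h  # grid row 9, grid column 4
# slices that select patch 130 in the original image
rows, cols = slice(i * 16, (i + 1) * 16), slice(j * 16, (j + 1) * 16)
patch = img[..., rows, cols]  # shape (1, 3, 16, 16)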
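A sketch of the full round trip the question asks for, assuming a random stand-in image and an arbitrary edit (inverting the pixel values) chosen only for illustration. It takes patch 130 out as a (1, 3, 16, 16) variable, checks it against the patch sliced from the image with the grid indices above, edits it, writes it back into the unfolded tensor, and folds the image together again.

```python
import torch
import torch.nn.functional as nnf
from math import ceil

img = torch.rand(1, 3, 224, 224)  # stand-in for the real image
uf = nnf.unfold(img, kernel_size=16, stride=16, padding=0)  # (1, 768, 196)

# 1. Patch 130 as its own variable, in image layout (1, 3, 16, 16)
patch = uf[..., 130].reshape(1, 3, 16, 16)

# 2. The same patch sliced straight from the image via its (row, column) grid index
h = ceil(224 / 16)        # 14 patches per row
i, j = 130 // h, 130 % h  # grid row 9, grid column 4
assert torch.equal(patch, img[..., i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16])

# 3. Edit the patch like any image tensor (illustrative edit: invert pixel values)
patch = 1.0 - patch

# 4. Flatten it back into a (1, 768) column and fold the image together again
uf[..., 130] = patch.reshape(1, -1)
f = nnf.fold(uf, img.shape[-2:], kernel_size=16, stride=16, padding=0)
```

Depending on memory layout, reshape may return either a view or a copy, so writing the edited patch back into uf explicitly (step 4) is the safe choice. Note also that basic slicing of img returns a view, so an in-place edit of the sliced patch changes img directly, with no unfold/fold round trip needed (as long as img does not require gradients).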

If you need gradients to flow through the result, multiply by a 0/1 mask instead of assigning in place, as follows:

import torch
from math import ceil

img = torch.rand(1, 3, 224, 224, requires_grad=True)

h = ceil(224 / 16)
i, j = 130 // h, 130 % h

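# 1 everywhere except 0 on the 16 x 16 block of patch 130
mask = torch.ones(224, 224)
mask[i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16] = 0

# out-of-place multiply; the mask broadcasts over batch and channels
img = img * mask  # img.requires_grad is True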
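A sketch that checks the gradient claim, assuming a toy loss (the sum of all pixels) chosen only for illustration. Unlike the snippet above, it stores the product in a new name, out, so that img remains the leaf tensor whose .grad can be inspected. (Assigning zeros in place to a leaf tensor with requires_grad=True raises a RuntimeError, which is one reason to prefer the multiplicative mask.)

```python
import torch
from math import ceil

img = torch.rand(1, 3, 224, 224, requires_grad=True)

h = ceil(224 / 16)
i, j = 130 // h, 130 % h

# 1 everywhere except 0 on patch 130
mask = torch.ones(224, 224)
mask[i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16] = 0

out = img * mask      # out-of-place, so autograd records the operation
out.sum().backward()  # toy loss: its gradient w.r.t. img equals the mask

# Gradient is 0 inside patch 130 and 1 everywhere else
patch_grad = img.grad[..., i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16]
print(patch_grad.abs().sum().item())  # 0.0
```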

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1