spatial domain convolution not equal to frequency domain multiplication using pytorch
I want to verify that 2D convolution in the spatial domain is really a multiplication in the frequency domain, so I used PyTorch to convolve an image with a 3×3 kernel (both real). Then I transformed both the image and the kernel into the frequency domain, multiplied them, and transformed the result back to the spatial domain. Here's the result:
When the kernel is even or odd (i.e. purely real or purely imaginary in the frequency domain), the results of the two approaches seem to match well. I compare the min and max of both results because I'm not sure whether some border-alignment issue might affect a direct elementwise difference. Here are three runs each with an even and an odd kernel:
# Even Kernel
min max with s-domain conv: -0.03659552335739136 4.378755569458008
min max with f-domain mul: -0.0365956649184227 4.378755569458008
min max with s-domain conv: -1.2673343420028687 2.397951126098633
min max with f-domain mul: -1.2673344612121582 2.397951126098633
min max with s-domain conv: -8.185677528381348 0.22980886697769165
min max with f-domain mul: -8.185677528381348 0.22980868816375732
# Odd Kernel
min max with s-domain conv: -1.6630988121032715 1.6592578887939453
min max with f-domain mul: -1.663098692893982 1.6592577695846558
min max with s-domain conv: -3.483165979385376 3.4751217365264893
min max with f-domain mul: -3.483165979385376 3.475121259689331
min max with s-domain conv: -1.7972984313964844 1.7931475639343262
min max with f-domain mul: -1.7972984313964844 1.7931475639343262
But if the kernel is neither even nor odd, the difference is on another level entirely:
min max with s-domain conv: -2.3028392791748047 1.675748348236084
min max with f-domain mul: -2.5289478302001953 1.4919483661651611
min max with s-domain conv: -1.1227827072143555 3.0336122512817383
min max with f-domain mul: -1.1954418420791626 2.9853036403656006
min max with s-domain conv: -1.6867876052856445 5.575590133666992
min max with f-domain mul: -1.6832940578460693 5.688591957092285
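For what it's worth, the min/max check above could be replaced by a direct elementwise comparison; a toy sketch of what I mean (`report_diff` is just an illustrative helper, not part of my code below):

```python
import torch

# Hypothetical helper: compare two result tensors elementwise
# instead of only via their min/max, which could agree by coincidence.
def report_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the maximum absolute elementwise difference."""
    return (a - b).abs().max().item()

a = torch.randn(1, 1, 8, 8)
b = a + 1e-7 * torch.randn_like(a)     # simulate tiny numerical noise

print(report_diff(a, b))               # tiny for matching results
print(torch.allclose(a, b, atol=1e-5)) # elementwise check with tolerance
```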
I was wondering whether this arises from floating-point precision, but switching to torch's complex128 wasn't any better. Is there something wrong with my implementation, or is this inevitable when computing with complex numbers?
Here's a simplified version of my code that reproduces this result:
import torch.nn.functional as F
import torch.fft as fft
import torch, cv2
img = cv2.imread('test.png', 0)
x = torch.as_tensor(img).unsqueeze(0)/255
k = torch.randn(1, 1, 3, 3)
for i in range(k.size(0)):
    for j in range(k.size(1)):
        # For even k
        # for p in range(k.size(2)):
        #     for q in range(k.size(3)):
        #         k[i, j, p, q] = k[i, j, 2-p, 2-q]
        # For odd k
        # for p in range(k.size(2)):
        #     k[i, j, p, 0] = -k[i, j, p, 2]
        #     k[i, j, p, 1] = 0
        # for q in range(k.size(3)):
        #     k[i, j, 0, q] = -k[i, j, 2, q]
        #     k[i, j, 1, q] = 0
        pass
### Spatial domain convolution
padx = F.pad(x, [1,1,1,1])
sdc = F.conv2d(padx.unsqueeze(0), k)
### Frequency domain convolution
# Transform input
fdx = fft.rfft2(x)
sdfdx = fft.irfft2(fdx)
# Transform kernel
size_diff = x.size(-1)-k.size(-1)  # assumes a square image
padk = torch.roll(F.pad(k, [0,size_diff,0,size_diff]), (-1,-1), (-1, -2))  # shift kernel center to origin
fdk = fft.rfft2(padk)
# Frequency domain multiplication
fdc = fdk * fdx
fdc = fdc.squeeze(0)
# Back to spatial domain
sdfdc = fft.irfft2(fdc)
### Compare
print("min max with s-domain conv:", sdc.min().item(), sdc.max().item())
print("min max with f-domain mul:", sdfdc.min().item(), sdfdc.max().item())
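For reference, here is a minimal 1D sketch of what I expect the convolution theorem to guarantee: pointwise multiplication of FFTs matches *circular* convolution computed from its definition. (Note that this uses true convolution with a flipped kernel; as I understand it, torch's `conv1d`/`conv2d` actually compute cross-correlation, so reproducing `y_direct` with them would require reversing the kernel first.)

```python
import torch
import torch.fft as fft

torch.manual_seed(0)
n = 8
x = torch.randn(n)
k = torch.randn(3)

# Zero-pad the kernel to the signal length.
padk = torch.zeros(n)
padk[:3] = k

# Circular convolution in the spatial domain, by definition:
# y[i] = sum_j padk[j] * x[(i - j) mod n]
y_direct = torch.stack([
    sum(padk[j] * x[(i - j) % n] for j in range(n)) for i in range(n)
])

# Frequency domain: pointwise multiplication of the transforms.
y_fft = fft.irfft(fft.rfft(x) * fft.rfft(padk), n)

print(torch.allclose(y_direct, y_fft, atol=1e-5))
```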
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow