Problem understanding convolutions: why is conv(data)[i] == conv(data[i].unsqueeze(0)) False?
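import torch
import torch.nn as nn

data = torch.ones(3, 3, 6, 6)  # batch of 3 samples, 3 channels, 6x6 each
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
print(data[0].unsqueeze(0).shape)  # a single sample with a batch dimension of 1
for i in range(3):
    # compare sample i of the batched output with sample i convolved on its own
    print((conv(data)[i] == conv(data[i].unsqueeze(0))).all())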
Results:
torch.Size([1, 3, 6, 6])
tensor(False)
tensor(False)
tensor(False)
I expected it to print True, but it printed False instead. Any idea why?
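A likely explanation, offered here as an assumption rather than a confirmed diagnosis, is floating-point rounding. A batch of 3 and a batch of 1 may be dispatched to different convolution kernels (or different tilings of the same kernel), which sum the 27 products behind each output element in a different order. Floating-point addition is not associative, so the two results can differ in the last few bits, and exact `==` then reports False even though the values agree to roughly 1e-7. The sketch below reuses the question's setup, measures the largest element-wise gap, and compares with `torch.allclose` instead of `==`. The fixed seed, the `torch.no_grad()` context, and the tolerance `atol=1e-6` are assumptions added for this illustration, not part of the original question.

```python
import torch
import torch.nn as nn

# Plain-Python illustration: the same three numbers summed in a different order.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c == a + (b + c))  # False in IEEE 754 double precision

torch.manual_seed(0)  # assumed: fixed seed so the random conv weights are reproducible

data = torch.ones(3, 3, 6, 6)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)

with torch.no_grad():
    # All three samples in one call: shape (3, 16, 6, 6).
    batched = conv(data)
    # One call per sample, each (1, 16, 6, 6); [0] drops the batch dim before stacking.
    single = torch.stack([conv(data[i].unsqueeze(0))[0] for i in range(data.shape[0])])

# Largest element-wise gap between the two ways of computing the same values.
max_diff = (batched - single).abs().max().item()
print(f"max abs difference: {max_diff:.3e}")

# Compare with a tolerance instead of exact equality.
# atol=1e-6 is an assumed tolerance that covers float32 rounding at this scale.
print(torch.allclose(batched, single, atol=1e-6))
```

If `max_diff` comes out around 1e-7 (or exactly 0 on a backend that happens to use the same summation order) while `allclose` prints True, the mismatch is rounding noise rather than a mistake in how `unsqueeze` or batching is used.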
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow