PyTorch error: mat1 and mat2 shapes cannot be multiplied
I get the error below. My input image has size [3072, 2, 2], so I flatten it with the following code, but I still receive:
mat1 and mat2 shapes cannot be multiplied (6144x2 and 12288x512)
My code:
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(12288, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)

    def forward(self, x):
        out = torch.flatten(x, 0)
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out
model = NeuralNet().to(device)

# Train the model
total_step = len(my_dataloader)
for epoch in range(5):
    for i, (images, labels) in enumerate(my_dataloader):
        # Move tensors to the configured device
        images = images.to(device)
        print(type(images))
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Solution 1:[1]
First, the in_features of your linear layer is incorrect. in_features must equal the size of the last dimension of the tensor you pass to the layer. For an unflattened input of shape [3072, 2, 2], that means nn.Linear(2, 512):
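import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(2, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)

    def forward(self, x):
        # No flatten here: nn.Linear acts on the last dimension (size 2),
        # so an input of shape [3072, 2, 2] gives an output of shape [3072, 2, 3].
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out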
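To see why the last dimension matters, here is a minimal sketch that reproduces the mismatch on a dummy tensor. It assumes the input really has shape [3072, 2, 2] with no extra batch dimension, as stated in the question; the names x, fc_ok and fc_bad are only for illustration.

```python
import torch
import torch.nn as nn

# Dummy input with the shape from the question (assumption: no batch dimension).
x = torch.randn(3072, 2, 2)

# nn.Linear multiplies along the last dimension only; leading dimensions are kept.
fc_ok = nn.Linear(2, 512)
out_shape = tuple(fc_ok(x).shape)
print(out_shape)  # (3072, 2, 512)

# in_features=12288 does not match the last dimension (2), so the call fails
# with a shape-mismatch RuntimeError like the one in the question.
fc_bad = nn.Linear(12288, 512)
try:
    fc_bad(x)
    error_message = None
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```

The leading dimensions 3072 and 2 are collapsed into 6144 rows for the matrix multiplication, which is where the 6144x2 in the error message comes from.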
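According to the PyTorch documentation, the signature is torch.flatten(input, start_dim=0, end_dim=-1), so torch.flatten(x, 0) already returns a tensor of shape [12288]. The problem is that forward stores this result in out and then calls self.fc1(x), so the layer still receives the unflattened [3072, 2, 2] tensor, which nn.Linear treats as a [3072*2, 2] matrix. If you want shape [12288] to match your linear in_features, flatten all dimensions and pass the flattened tensor to fc1: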
import torch
import torch.nn as nn

class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(12288, 512)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(512, 3)

    def forward(self, x):
        # Flatten every dimension: [3072, 2, 2] -> [12288]
        out = torch.flatten(x, start_dim=0, end_dim=-1)
        # Pass the flattened tensor (out), not the original x
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        return out
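The effect of the start_dim and end_dim arguments can be checked on a dummy tensor. The sketch below is only an illustration: the batch size of 4 in the last part is a hypothetical value, and whether your DataLoader adds a batch dimension is not stated in the question.

```python
import torch
import torch.nn as nn

x = torch.randn(3072, 2, 2)

# Defaults are start_dim=0, end_dim=-1, so both calls flatten everything.
flat = torch.flatten(x)
same = torch.flatten(x, 0)
print(flat.shape)  # torch.Size([12288])

# Flattening only the first two dimensions gives the [3072*2, 2] shape.
partial = torch.flatten(x, 0, 1)
print(partial.shape)  # torch.Size([6144, 2])

fc1 = nn.Linear(12288, 512)
hidden = fc1(flat)
print(hidden.shape)  # torch.Size([512])

# Hypothetical batched case: keep dimension 0 as the batch and flatten the rest.
batch = torch.randn(4, 3072, 2, 2)
batch_flat = torch.flatten(batch, start_dim=1)
batch_hidden = fc1(batch_flat)
print(batch_hidden.shape)  # torch.Size([4, 512])
```

If your DataLoader does yield batches, flattening with start_dim=1 keeps one row per sample, which is usually what the loss function expects.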
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | khashayar ehteshami |
