PyTorch model input shape
I loaded a custom PyTorch model and I want to find out its input shape. Something like this:
```python
model.input_shape
```
Is it possible to get this information?
Update: print() and summary() don't show this model's input shape, so they are not what I'm looking for.
Solution 1:[1]
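```python
print(model)
```
This prints a summary of the model's submodules and their configuration (for example `in_features` of a `Linear` layer, or `in_channels` and `kernel_size` of a `Conv2d`). From that you can often read off what the first layer expects.
You can also use the pytorch-summary package, although its `summary()` function requires you to pass the input size yourself.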
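As a rough sketch of reading that information programmatically instead of by eye, the loop below walks `named_modules()` and reports what the first `Linear` or `Conv` layer expects. The model and the helper name `first_layer_input_info` are hypothetical, and the result is only meaningful if the modules are registered in the same order `forward()` uses them (true for `nn.Sequential`, not guaranteed for custom modules):

```python
import torch.nn as nn

# Hypothetical example model; substitute your own loaded model.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 30 * 30, 10),
)


def first_layer_input_info(model: nn.Module):
    """Return (name, description) for the first Linear or Conv layer found.

    Assumes registration order matches execution order.
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            return name, f"expects last dim = {module.in_features}"
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            return name, f"expects {module.in_channels} input channels"
    return None


print(model)  # layer-by-layer repr, including in_channels / in_features
print(first_layer_input_info(model))
```

For a convolutional first layer this only recovers the channel count; the spatial size still has to be worked out from the later layers.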
If your network has a fully connected (FC) layer first, its `in_features` tells you the input size directly. You mention that you have a convolutional layer at the front. Convolutions accept a range of spatial sizes, but if FC layers follow, the network will only accept input sizes whose conv output flattens to exactly the `in_features` of the first FC layer. I suggest finding this by experiment: feed a toy batch of some shape, then check the output shape of the conv layer just before the FC layer.
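One way to check the output just before the first FC layer, without editing the model, is a forward pre-hook on that layer. A minimal sketch, assuming a hypothetical `nn.Sequential` model with a single-channel input; `features_reaching_fc` is an illustrative helper, not a PyTorch API:

```python
import torch
import torch.nn as nn

# Hypothetical conv + FC model standing in for your own; the FC layer was
# sized for a 28x28 single-channel input.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5),  # 28x28 -> 24x24
    nn.MaxPool2d(2),                 # 24x24 -> 12x12
    nn.Flatten(),
    nn.Linear(8 * 12 * 12, 10),      # in_features = 1152
)
fc = model[3]  # the first FC layer


def features_reaching_fc(model, fc, input_shape):
    """Return the shape of the tensor fed into `fc` for a zero batch of `input_shape`.

    Returns None if an earlier layer (e.g. a conv) rejects the input.
    """
    captured = {}

    def hook(module, inputs):
        # Forward pre-hook: runs before fc.forward, so it fires even if
        # fc then fails with a shape mismatch.
        captured["shape"] = tuple(inputs[0].shape)

    handle = fc.register_forward_pre_hook(hook)
    try:
        with torch.no_grad():
            model(torch.zeros(input_shape))
    except RuntimeError:
        pass  # mismatch somewhere; whatever the hook saw is still recorded
    finally:
        handle.remove()
    return captured.get("shape")


print(fc.in_features)                                   # 1152
print(features_reaching_fc(model, fc, (1, 1, 28, 28)))  # (1, 1152): fits
print(features_reaching_fc(model, fc, (1, 1, 32, 32)))  # (1, 1568): too many
```

Comparing the captured width with `fc.in_features` tells you whether a trial input size is too large or too small.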
Because this depends on the architecture before the first FC layer (number of conv layers, kernel sizes, strides, padding, etc.), I can't give you an exact formula for the correct input. As mentioned, you have to figure it out by experimenting with various input shapes and inspecting the net's output before the first FC layer. There's (almost) always a way to solve something with code, but I can't think of anything else right now.
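The trial-and-error search can be automated by trying candidate sizes and keeping those that run without error. A sketch under the same assumptions (a hypothetical conv/pool/FC model, square single-channel inputs); `accepted_sizes` is an illustrative helper name:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model: conv -> pool -> flatten -> FC.
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=5),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 12 * 12, 10),
)


def accepted_sizes(model, channels, sizes, batch=1):
    """Try square inputs of each size and return the sizes that run without error."""
    accepted = []
    model.eval()
    with torch.no_grad():
        for s in sizes:
            try:
                model(torch.zeros(batch, channels, s, s))
            except RuntimeError:
                continue  # shape mismatch somewhere in the network
            accepted.append(s)
    return accepted


print(accepted_sizes(model, channels=1, sizes=range(8, 64)))  # [28, 29]
```

Two sizes pass here because `MaxPool2d(2)` floors odd sizes (25 -> 12), so with pooling or strided convolutions more than one input size can fit the FC layer.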
Solution 2:[2]
You can get information about the input shape from the first tensor in the model's parameters.
For example, create a model:
```python
import torch.nn as nn
import torch.nn.functional as F


class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1568, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 20)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.fc2(out)
        out = F.relu(out)
        out = self.fc3(out)
        return out


model = CustomNet()
```
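The `model.parameters()` method returns an iterator over the module's parameters, which are `nn.Parameter` objects (a subclass of `torch.Tensor`). See the docs: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.parameters
The first parameter is not the input tensor itself but the weight of the first registered layer (here `fc1.weight`). `nn.Linear` stores its weight with shape `(out_features, in_features)`, so the expected input size is the last dimension:
```python
first_parameter = next(model.parameters())
input_shape = first_parameter.size()  # torch.Size([256, 1568])
in_features = input_shape[-1]         # 1568
```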
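The same idea works with `named_parameters()`, which also shows which layer the first tensor belongs to. A minimal sketch with a hypothetical `nn.Sequential` stand-in for `CustomNet`; it assumes the first registered layer is also the first one used in `forward()`, and note that a conv weight only reveals the channel count, not the height and width:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for CustomNet: same Linear sizes, built with nn.Sequential.
model = nn.Sequential(nn.Linear(1568, 256), nn.ReLU(), nn.Linear(256, 20))

# named_parameters() yields (name, tensor) pairs in registration order.
name, weight = next(model.named_parameters())
print(name, tuple(weight.shape))  # 0.weight (256, 1568)

# nn.Linear stores its weight as (out_features, in_features).
in_features = weight.shape[1]

# Sanity check: a batch of 4 inputs with that width runs through the model.
out = model(torch.randn(4, in_features))
print(tuple(out.shape))  # (4, 20)

# A conv weight is (out_channels, in_channels // groups, kH, kW):
# it gives the channel count but says nothing about the input height/width.
conv = nn.Conv2d(3, 16, kernel_size=3)
print(tuple(next(conv.parameters()).shape))  # (16, 3, 3, 3)
```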
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | |
