How to create unnamed PyTorch parameters in state dict?
I am trying to load a model checkpoint (.ckpt file) for transfer learning. I do not have the model's source code, so I am trying to recreate it with PyTorch, like this:
import torch
import torch.nn as nn
import torch.nn.functional as F

class IngrDetNet(nn.Module):
    def __init__(self):
        super(IngrDetNet, self).__init__()
        self.fc1 = nn.Linear(n_ingr, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, n_ingr)

    def forward(self, x):
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc4(x)
        ingrs = torch.sigmoid(x)  # F.sigmoid is deprecated; torch.sigmoid is equivalent
        return ingrs

# Create a basic model instance
model = IngrDetNet()
model = torch.nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
The syntax is based on this PyTorch tutorial.
Then, I am trying to load in the checkpoint's state dict, based on this PyTorch tutorial:
model_path = Path('../Models/PITA/medr2idalle41.ckpt')
model.load_state_dict(torch.load(model_path, map_location=map_loc)['weights_id'])
But I get a key error mismatch:
Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.fc1.weight", "module.fc1.bias", "module.fc2.weight", "module.fc2.bias", "module.fc3.weight", "module.fc3.bias", "module.fc4.weight", "module.fc4.bias".
Unexpected key(s) in state_dict: "module.model.0.weight", "module.model.0.bias", "module.model.2.weight", "module.model.2.bias", "module.model.4.weight", "module.model.4.bias".
The checkpoint state dict contains indexed keys such as module.model.0.weight, whereas my architecture contains named parameters such as module.fc1.weight. How do I generate layers in such a way that my parameters are not named but indexed?
Solution 1:[1]
The keys mismatch because the two models are not built the same way. The nn.DataParallel wrapper is not the culprit here: it prefixes every key with "module.", and that prefix appears on both sides of the error message. The extra nesting comes from the checkpoint itself: the original network evidently kept its layers in an nn.Sequential assigned to an attribute named model, which is why its parameters are indexed, e.g. module.model.0.weight. A Sequential numbers its children 0, 1, 2, …, and the parameter-free activations at the odd indices contribute no state-dict keys, which is why only indices 0, 2 and 4 appear.
You can remap the keys yourself before applying them to the model. A dict comprehension should do, for example:
>>> state = torch.load(model_path, map_location=map_loc)['weights_id']
>>> rename = {'module.model.0': 'module.fc1',
...           'module.model.2': 'module.fc2',
...           'module.model.4': 'module.fc3'}
>>> state = {rename.get(k.rsplit('.', 1)[0], k.rsplit('.', 1)[0]) + '.' + k.rsplit('.', 1)[1]: v
...          for k, v in state.items()}
>>> model.load_state_dict(state, strict=False)
Note, however, that the checkpoint contains only three Linear layers (indices 0, 2, 4) while IngrDetNet defines four (fc1 through fc4), so a pure key rename cannot produce a full match: fc4 stays randomly initialized (hence strict=False), and the renamed layers' shapes must also agree with your definitions. The cleaner fix is to mirror the original structure with an indexed nn.Sequential.
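Since the question asks specifically how to get indexed rather than named parameters, here is a minimal sketch of a rebuilt IngrDetNet that stores its layers in an nn.Sequential attribute called model, so its state-dict keys take the module.model.N.weight form the checkpoint uses. The layer sizes and activations below are assumptions carried over from the question's code; the real values should be recovered from the checkpoint tensors' shapes.

```python
import torch
import torch.nn as nn

n_ingr = 100  # placeholder: recover the real size from the checkpoint, e.g.
              # {k: v.shape for k, v in torch.load(model_path)['weights_id'].items()}

class IngrDetNet(nn.Module):
    """Layers live in an nn.Sequential named `model`, so parameters are
    indexed (model.0, model.2, model.4) instead of named (fc1, fc2, ...)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(n_ingr, 1024),  # -> model.0.weight / model.0.bias
            nn.LeakyReLU(0.2),        # index 1: no parameters, no keys
            nn.Linear(1024, 512),     # -> model.2.*
            nn.LeakyReLU(0.2),        # index 3
            nn.Linear(512, n_ingr),   # -> model.4.*
        )

    def forward(self, x):
        return torch.sigmoid(self.model(x))

model = torch.nn.DataParallel(IngrDetNet())
print(sorted(model.state_dict().keys()))
# DataParallel adds the leading 'module.', giving keys such as
# 'module.model.0.weight' -- the same layout as the checkpoint.
```

With matching shapes, load_state_dict should then succeed without any key surgery.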
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Stack Overflow |
