How to create unnamed PyTorch parameters in state dict?

I am trying to load a model checkpoint (.ckpt file) for transfer learning. I do not have the model's source code, so I am trying to recreate it with PyTorch, like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

# n_ingr (the number of ingredient classes) must be defined before building the
# model; the value below is a placeholder and must match the checkpoint
n_ingr = 1024

class IngrDetNet(nn.Module):
    def __init__(self):
        super(IngrDetNet, self).__init__()
        self.fc1 = nn.Linear(n_ingr, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, n_ingr)

    def forward(self, x):
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc4(x)
        
        ingrs = torch.sigmoid(x)  # F.sigmoid is deprecated
        return ingrs

# Create a basic model instance
model = IngrDetNet()
model = torch.nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

The syntax is based on this PyTorch tutorial.

Then, I am trying to load in the checkpoint's state dict, based on this PyTorch tutorial:

from pathlib import Path

map_loc = torch.device('cpu')  # adjust to the device you load onto
model_path = Path('../Models/PITA/medr2idalle41.ckpt')
model.load_state_dict(torch.load(model_path, map_location=map_loc)['weights_id'])

But I get a key mismatch error:

Error(s) in loading state_dict for DataParallel:
    Missing key(s) in state_dict: "module.fc1.weight", "module.fc1.bias", "module.fc2.weight", "module.fc2.bias", "module.fc3.weight", "module.fc3.bias", "module.fc4.weight", "module.fc4.bias". 
    Unexpected key(s) in state_dict: "module.model.0.weight", "module.model.0.bias", "module.model.2.weight", "module.model.2.bias", "module.model.4.weight", "module.model.4.bias". 

The checkpoint state dict contains indexed keys such as module.model.0.weight, whereas my architecture contains named parameters such as module.fc1.weight. How do I generate layers in such a way that my parameters are not named but indexed?
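Before calling load_state_dict, it can help to diff the model's expected keys against the checkpoint's keys yourself. A minimal sketch, with a toy in-memory dict standing in for torch.load(...)['weights_id']:

```python
import torch
import torch.nn as nn

# Toy model whose keys will not match the toy "checkpoint" below.
model = torch.nn.DataParallel(nn.Sequential(nn.Linear(4, 4)))
ckpt_state = {
    "module.fc1.weight": torch.zeros(4, 4),
    "module.fc1.bias": torch.zeros(4),
}

expected = set(model.state_dict().keys())   # {'module.0.weight', 'module.0.bias'}
found = set(ckpt_state.keys())

print("missing:   ", sorted(expected - found))  # keys the checkpoint lacks
print("unexpected:", sorted(found - expected))  # keys the model does not define
```

These two sets are exactly what load_state_dict reports as "Missing key(s)" and "Unexpected key(s)".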



Solution 1:[1]

The full mismatch of the keys happens because of the nn.DataParallel module utility. It wraps your original model and stores it under the attribute module, so every key in the state dict gains a "module." prefix. In other words:

>>> model = IngrDetNet()                  # model is an IngrDetNet
>>> model = torch.nn.DataParallel(model)  # model.module is an IngrDetNet

Notice, though, that both key sets in your error already carry the "module." prefix, so DataParallel is consistent on both sides. The remaining difference, "model.0.weight" vs. "fc1.weight", means the checkpoint was saved from a network whose layers sit in an indexed nn.Sequential stored in an attribute named model, while your reimplementation registers them as individually named attributes fc1 … fc4.
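You can see the wrapper's effect on the keys with a toy module (layer size arbitrary, class name hypothetical):

```python
import torch.nn as nn
from torch.nn import DataParallel

class Net(nn.Module):
    """Tiny stand-in for IngrDetNet with one named layer."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)

plain = Net()
wrapped = DataParallel(Net())  # stores the model under the attribute .module

print(list(plain.state_dict()))    # ['fc1.weight', 'fc1.bias']
print(list(wrapped.state_dict()))  # ['module.fc1.weight', 'module.fc1.bias']
```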


You can fix this by mirroring that structure: build your layers inside an nn.Sequential assigned to an attribute called model. The parameters are then indexed rather than named, giving keys of the form module.model.0.weight that line up with the checkpoint. Also note the checkpoint holds only three Linear layers (indices 0, 2, 4, with the activations presumably at indices 1 and 3), not four; print the tensor shapes to recover the exact layer sizes:

>>> state = torch.load(model_path, map_location=map_loc)
>>> for k, v in state['weights_id'].items():
...     print(k, tuple(v.shape))

Once the architecture matches key-for-key, model.load_state_dict(state['weights_id']) succeeds.
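As a sanity check, a round trip on a toy version of this layout shows the keys lining up: an nn.Sequential stored in an attribute named model, wrapped in DataParallel, yields exactly the module.model.&lt;i&gt;.* key pattern from the error message. Layer sizes here are illustrative, not the checkpoint's:

```python
import torch
import torch.nn as nn

n_ingr = 16  # illustrative; the real value must match the checkpoint tensors

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Indexed layers: state-dict keys become model.0.*, model.2.*, model.4.*
        self.model = nn.Sequential(
            nn.Linear(n_ingr, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, n_ingr),
        )

    def forward(self, x):
        return torch.sigmoid(self.model(x))

source = nn.DataParallel(Net())
state = source.state_dict()
print(sorted(state)[:2])  # ['module.model.0.bias', 'module.model.0.weight']

# A fresh instance with the same layout accepts the state dict as-is.
target = nn.DataParallel(Net())
target.load_state_dict(state)
```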

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
