'Pytorch RuntimeError: Error(s) in loading state_dict for SimCLR:

I am trying to train my SimCLR model using LSTM network and getting the error: Pytorch RuntimeError: Error(s) in loading state_dict for SimCLR: It appears as if the state_dictionary in LSTM model is not loaded since the model is not pretrained. How to correct the error?

RuntimeError: Error(s) in loading state_dict for SimCLR:
    Missing key(s) in state_dict: "lstmnet.lstm.weight_ih_l0", "lstmnet.lstm.weight_hh_l0", "lstmnet.lstm.bias_ih_l0", "lstmnet.lstm.bias_hh_l0", "lstmnet.fc.weight", "lstmnet.fc.bias". 
    size mismatch for convnet.fc.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([2, 128]).
    size mismatch for convnet.fc.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([2]).

Below is the code:

class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim

        # Number of hidden layers
        self.layer_dim = layer_dim

        # Building your LSTM
        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.dropout = nn.Dropout(0.1)
        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_dim)
        

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # Initialize cell state
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # 28 time steps
        # We need to detach as we are doing truncated backpropagation through time (BPTT)
        # If we don't, we'll backprop all the way to the start even after going through another batch
        x, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
        x = self.dropout(x) 
        # Index hidden state of last time step
        # out.size() --> 100, 28, 100
        # out[:, -1, :] --> 100, 100 --> just want last time step hidden states! 
        x = self.fc(x[:, -1, :]) 
        # out.size() --> 100, 10
       
        return (x)    

class SimCLR(pl.LightningModule):
    
    def __init__(self, hidden_dim, lr, temperature, weight_decay, max_epochs=10):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        input_dim = 3
        hidden_dim = hidden_dim
        layer_dim = 1  # ONLY CHANGE IS HERE FROM ONE LAYER TO TWO LAYER
        output_dim = 2

        self.lstmnet = LSTMModel(input_dim, hidden_dim, layer_dim, output_dim )  # Output of last linear layer
        self.convnet = torchvision.models.resnet18(pretrained=False, 
                                                   num_classes=4*hidden_dim)
        # The MLP for g(.) consists of Linear->ReLU->Linear 
        self.convnet.fc = nn.Sequential(
            self.lstmnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source