Did I initialize my LSTM correctly, or do my states reset all the time?

I was trying to reset my model for every fold in a cross-validation when I realized I might have initialized my model all wrong. I am very confused right now, because it looks like every forward pass in my model resets my hidden state and cell state to zeros. I think I added the init_hidden function when I went from training only with batch size = 1 to varying batch sizes.

I am now rethinking everything I did over the last week; I don't know whether my model works correctly like this or not.

import torch
import torch.nn as nn


class Model_GRU_1(nn.Module):

    def __init__(self, n_features, n_classes, n_hidden, n_layers, dropout):
        super().__init__()
        self.gru = nn.GRU(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )

        self.dense = nn.Linear(n_hidden, n_hidden)
        self.relu = nn.ReLU()
        weight = torch.zeros(n_layers, n_hidden)
        nn.init.kaiming_uniform_(weight)
        self.weight = nn.Parameter(weight)

        self.classifier = nn.Linear(n_hidden, n_classes)

    def init_hidden(self):
        # batch_size is assumed to be defined in the enclosing (global) scope
        hidden_state = torch.zeros(self.gru.num_layers, batch_size, self.gru.hidden_size)
        cell_state = torch.zeros(self.gru.num_layers, batch_size, self.gru.hidden_size)
        return (hidden_state, cell_state)

    def forward(self, x):
        # fresh zero states are created on every forward pass ...
        self.hidden = self.init_hidden()
        _, hidden = self.gru(x)  # ... but they are never passed to the GRU
        out = hidden[-1]
        out2 = self.dense(out)
        out3 = self.relu(out2)
        return self.classifier(out3)


Solution 1:[1]

You're currently resetting your hidden states to zero on each batch, but you're never passing them to self.gru.forward, which has an optional hx argument. When hx is None, PyTorch initializes it for you -- to zeros, so the result is the same as if you had passed your freshly created hidden state yourself. If you're happy with the results from this model, you can just as easily remove everything relating to self.hidden, including init_hidden.
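To make that concrete, here is a minimal sketch of the same model with everything related to self.hidden removed; it keeps the constructor arguments from the question, and it also drops the self.weight parameter, since nothing in forward reads it:

import torch
import torch.nn as nn


class Model_GRU_1(nn.Module):

    def __init__(self, n_features, n_classes, n_hidden, n_layers, dropout):
        super().__init__()
        self.gru = nn.GRU(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True
        )
        self.dense = nn.Linear(n_hidden, n_hidden)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        # hx is omitted, so PyTorch zero-initializes the hidden state
        # with the correct shape for this batch -- the same effect as
        # the old init_hidden, without hard-coding batch_size.
        _, hidden = self.gru(x)
        out = self.relu(self.dense(hidden[-1]))
        return self.classifier(out)

If you ever do want to pass an initial state explicitly, note that nn.GRU expects hx to be a single tensor of shape (num_layers * num_directions, batch, hidden_size); a GRU has no cell state, so the LSTM-style (hidden_state, cell_state) tuple that the original init_hidden returns would not be accepted.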

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: erip