How to avoid an LSTM classifying everything into only one class?

I'm trying to build automatic punctuation restoration for English text. I have the following network:

class CharLSTM(nn.Module):
    """Character-level bidirectional LSTM classifier that carries its state.

    Characters are embedded, passed through a linear projection and fed to a
    bidirectional LSTM one character at a time. The LSTM state is stored on
    the module between calls (see ``reset_hidden``), so a text can be
    processed either as one sequence or as successive single-character calls.
    A small feed-forward head turns the LSTM output for the most recent
    character into ``output_size`` class scores.

    Args:
        voc_size: Number of distinct character ids (embedding rows).
        emb_dim: Size of the character embedding.
        hidden_size: LSTM hidden size per direction.
        output_size: Number of classes to score.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, voc_size, emb_dim, hidden_size, output_size, num_layers=1):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size

        self.emb = nn.Embedding(voc_size, emb_dim)
        self.fc3 = nn.Linear(emb_dim, emb_dim)
        self.rnn = nn.LSTM(
            emb_dim,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # The LSTM is bidirectional, so its output feature size is doubled.
        self.fc2 = nn.Linear(hidden_size * 2, output_size * 10)
        self.do = nn.Dropout(0.4)
        self.fc = nn.Linear(output_size * 10, output_size)

        self.reset_hidden()

    def forward(self, inp):
        """Consume one character or a sequence and score the last character.

        Args:
            inp: Integer tensor of character ids; either a 0-dim tensor (one
                character) or a 1-D tensor (a sequence of characters).

        Returns:
            Tensor of shape ``(1, 1, output_size)`` holding the raw class
            scores computed after the last character consumed.

        Raises:
            ValueError: If ``inp`` contains no characters.
        """
        # Follow the model's own parameters instead of a module-level global,
        # so the class works wherever the model has been moved.
        dev = self.emb.weight.device

        # The projection is applied on both the single-character and the
        # sequence path so that the two feeding modes give the same result.
        data = self.fc3(self.emb(inp.to(dev)))

        # Treat a single embedded character as a sequence of length one.
        steps = data.unsqueeze(0) if data.ndim == 1 else data
        if steps.shape[0] == 0:
            raise ValueError("CharLSTM.forward() needs at least one character")

        # NOTE(review): the state is carried over without detach(); when
        # training across several forward() calls, either call
        # reset_hidden() between backward passes or detach the state.
        hidden = (self.hidden1.to(dev), self.hidden2.to(dev))
        for step in steps:
            out, hidden = self.rnn(step.reshape(1, 1, -1), hidden)
        self.hidden1, self.hidden2 = hidden

        # Only the last step is returned, so the head runs once, after the
        # recurrent loop, rather than on every intermediate step.
        score = self.fc2(out)
        score = self.do(score)
        score = F.relu(score)
        return self.fc(score)

    def reset_hidden(self):
        """Zero the stored LSTM hidden and cell state (start a new text)."""
        dev = self.emb.weight.device
        # (num_layers * 2 directions, batch of 1, hidden_size)
        shape = (self.num_layers * 2, 1, self.hidden_size)
        self.hidden1 = torch.zeros(shape, device=dev)
        self.hidden2 = torch.zeros(shape, device=dev)

My dataset contains lines of text and each line contains 400 characters.

Now I'm trying to feed my network character by character, and I expect one of four classes as output:

  • No punctuation
  • Character "."
  • Character ","
  • Character "?"

But my network always predicts "no punctuation".

I know that 95% of the dataset belongs to the "No punctuation" class, so what do I have to do to make the network predict all 4 classes?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source