Why is my pretrained BERT model always predicting the most frequent tokens (including [PAD])?

I am trying to further pretrain a Dutch BERT model with MLM on an in-domain (legal) dataset. I have set up my entire preprocessing and training pipeline, but when I use the trained model to predict a masked word, it always outputs the same words in the same order, including the [PAD] token. This is strange, because I thought it shouldn't be able to predict the padding token at all, since my code makes sure padding tokens are never masked. See the picture of my model's predictions.

I have tried using more data (more than 50,000 instances) and more epochs (about 20). I have gone through my code and am pretty sure it gives the model the right input. The English version of the model seems to work, which makes me wonder whether the Dutch model is less robust.
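Note that never *masking* padding tokens is not the same as excluding them from the *loss*. Hugging Face's `BertForMaskedLM` computes the loss over every position whose label is not -100, so if `labels` is simply a copy of the unmasked `input_ids`, every [PAD] position (and every unmasked token) is still a training target. A minimal sketch for checking one batch, assuming `labels` is an integer tensor and `pad_id` is the tokenizer's padding id (the helper name `label_stats` is hypothetical):

```python
import torch

def label_stats(labels: torch.Tensor, pad_id: int) -> dict:
    """Summarise which label positions would contribute to the MLM loss.

    Positions labelled -100 are ignored by the loss; every other position is a target.
    """
    active = labels != -100                          # positions the loss actually sees
    n_active = int(active.sum())
    n_pad = int((active & (labels == pad_id)).sum())  # [PAD] ids used as targets
    return {
        "active": n_active,
        "pad_targets": n_pad,
        "pad_fraction": n_pad / n_active if n_active else 0.0,
    }

# Example: labels copied straight from input_ids, with padding at the end
# (pad_id=0 is only an example value here).
labels = torch.tensor([[2, 57, 81, 3, 0, 0, 0, 0]])
print(label_stats(labels, pad_id=0))
```

In real code, `tokenizer.pad_token_id` is safer than a hard-coded id. A large `pad_fraction` would mean much of the training signal rewards predicting [PAD], which would be consistent with the predictions described above, though it does not by itself prove that this is the cause.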

Does anyone know possible causes or solutions for this? Or is it possible that my language model simply doesn't work?

Here are my training loop and mask function, in case I overlooked a mistake in them:

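import torch

def mlm(tensor):
    # Pick ~15% of positions at random, skipping ids 0-3
    # (assumed to be [PAD], [UNK], [CLS] and [SEP] in this vocabulary).
    rand = torch.rand(tensor.shape, device=tensor.device)
    mask_arr = (rand < 0.15) & (tensor > 3)
    # Overwrite the selected positions in place with id 4 (assumed to be [MASK]).
    tensor[mask_arr] = 4
    return tensor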
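The function above only changes the inputs; it does not say which positions the loss should score. A common pattern is to return the masked inputs together with labels that keep the original id at masked positions and -100 everywhere else. A minimal sketch, assuming the caller supplies the special-token ids and the [MASK] id (the name `mlm_with_labels` is hypothetical). Like the original function it replaces every selected token with [MASK], rather than BERT's original 80% [MASK] / 10% random / 10% unchanged split:

```python
import torch

def mlm_with_labels(input_ids, special_ids, mask_token_id, prob=0.15):
    """Mask ~`prob` of the non-special tokens and build matching MLM labels.

    Returns (masked_input_ids, labels). Labels hold the original token id at
    masked positions and -100 everywhere else, so unmasked tokens and padding
    are ignored by the loss.
    """
    masked = input_ids.clone()   # leave the caller's tensor untouched
    labels = input_ids.clone()
    special = torch.tensor(special_ids, device=input_ids.device)
    # True where a token id appears in special_ids (broadcast over the last dim).
    is_special = (input_ids.unsqueeze(-1) == special).any(dim=-1)
    rand = torch.rand(input_ids.shape, device=input_ids.device)
    selected = (rand < prob) & ~is_special
    labels[~selected] = -100     # only masked positions are loss targets
    masked[selected] = mask_token_id
    return masked, labels
```

With a Hugging Face tokenizer, `special_ids` could be `tokenizer.all_special_ids` and `mask_token_id` could be `tokenizer.mask_token_id`. `transformers.DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)` is an existing alternative that builds labels this way and also applies the 80/10/10 split.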

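import torch
from tqdm import tqdm

# model, loader and device are assumed to be defined earlier
# (a BERT masked-LM model, a DataLoader yielding dicts, and a torch.device).
model.train()
# Named `optimizer` so the torch.optim module is not shadowed.
# Note: 5e-3 is far above the learning rates usually used for BERT (roughly 1e-5 to 1e-4).
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
epochs = 1
losses = []
for epoch in range(epochs):
    epochloss = []
    loop = tqdm(loader, leave=True)
    for batch in loop:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        # Store a Python float rather than the tensor, which carries its autograd history.
        epochloss.append(loss.item())
        loss.backward()
        optimizer.step()
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
    losses.append(epochloss)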
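The loop uses plain Adam at a learning rate of 0.005. For comparison, the BERT paper used Adam with a learning rate of 1e-4 (plus warmup and linear decay) for pretraining and 2e-5 to 5e-5 for fine-tuning. A minimal sketch of a comparable setup in plain PyTorch follows; `make_optimizer` and all default values are hypothetical and untuned, and the question itself does not establish that the learning rate explains the behavior:

```python
import torch

def make_optimizer(params, lr=5e-5, warmup_steps=1000, total_steps=10000, weight_decay=0.01):
    """AdamW with linear warmup followed by linear decay to zero.

    The defaults are illustrative assumptions, not tuned values.
    """
    optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

    def lr_lambda(step):
        # Multiplier applied to the base learning rate at a given step.
        if step < warmup_steps:
            return step / max(1, warmup_steps)                    # ramp up from 0
        remaining = total_steps - step
        return max(0.0, remaining / max(1, total_steps - warmup_steps))  # decay to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler
```

Inside the batch loop, `scheduler.step()` would be called right after `optimizer.step()`, with `total_steps` set to roughly `epochs * len(loader)`.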


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
