Mistake in training loop or data preparation for BERT-based model leads to non-convergence of loss

I'm trying to train a multimodal transformer (BERT-based) for classification on combined text and tabular data, but neither the training loss nor the validation loss decreases in any significant way. I've also noticed that these multimodal models never reach the performance of a classifier that uses the tabular data alone. To investigate, I wrote a stripped-down version of BertForSequenceClassification and found that even this model (BERT plus a single linear classifier) doesn't converge on either training or validation loss. From this I deduced that I must have made a mistake in my training loop. (I'm not using the transformers Trainer API because multimodal-transformers requires transformers v3.1.0, whose Trainer class has a memory leak.) My dataset and classification model are as follows:

class myDataset(torch.utils.data.Dataset):
    def __init__(self, df, num_classes):
        self.num_classes = num_classes
        self.y = df["rating_num"].to_numpy()  # integer labels, 9 classes
        self.length = len(self.y)
        # .tolist() avoids pandas label-based lookup in __getitem__, which
        # breaks when the DataFrame index is not a clean 0..n-1 range
        self.values = df["economy"].tolist()  # texts

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        txt = self.values[idx]
        label = self.y[idx]
        return txt, label

def collate_fn(data, tokenizer=tokenizer, device=device):
    texts, labels = zip(*data)
    # the tokenizer expects a list of strings, not the tuple zip() produces
    toks = tokenizer(list(texts), padding="max_length", truncation=True, return_tensors="pt")
    labels = torch.LongTensor(labels)
    return toks.to(device), labels.to(device)

train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, collate_fn=collate_fn)
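One cheap check on top of the data pipeline above: nn.CrossEntropyLoss expects labels to be integer class indices in [0, num_classes), so a malformed rating_num column would silently stall training. A small sketch of such a check (check_labels is a hypothetical helper, not part of the pipeline above):

```python
import numpy as np

def check_labels(labels, num_classes):
    """Sanity-check that labels are valid class indices for nn.CrossEntropyLoss:
    integers in the range [0, num_classes)."""
    labels = np.asarray(labels)
    assert np.issubdtype(labels.dtype, np.integer), "labels must be integers"
    assert labels.min() >= 0 and labels.max() < num_classes, \
        f"labels must lie in [0, {num_classes}), got [{labels.min()}, {labels.max()}]"
    return True

check_labels([0, 1, 8], 9)  # passes for 9 classes labelled 0..8
```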

class myClassificationBERT(nn.Module):
    def __init__(self, bert, output_dim, dropout):
        super().__init__()
        self.bert = bert
        clf_input_dim = bert.config.hidden_size
        self.classifier = nn.Linear(clf_input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # use self.bert, not the global bert (they happen to be the same
        # object here, but relying on the global is fragile)
        output = self.bert(**x)[1]  # pooled_output
        output = self.dropout(output)
        output = self.classifier(output)
        return output

As you can see, the model is deliberately simple; it exists only to pinpoint where the mistake in my pipeline is. I believe the dataset and dataloader classes are fine, since decoding the input_ids of the loaded batches gives back the original texts. Model init and training loop follow:

# model init
bert = BertModel.from_pretrained("bert-base-uncased")
model = myClassificationBERT(bert=bert, output_dim=num_class, dropout=.2).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fct = nn.CrossEntropyLoss()
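For AdamW to update the BERT weights at all, they have to be reachable through model.parameters(). That is guaranteed here by the self.bert = bert assignment in __init__, since nn.Module registers any submodule assigned as an attribute. A toy sketch of that registration rule (Wrapper and the layer sizes are made up for illustration):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, sub):
        super().__init__()
        self.sub = sub                # attribute assignment registers the submodule
        self.head = nn.Linear(4, 2)

m = Wrapper(nn.Linear(8, 4))
# sub: 8*4 weights + 4 biases = 36; head: 4*2 + 2 = 10 → 46 in total
n_params = sum(p.numel() for p in m.parameters())
print(n_params)  # 46
```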

##########
#training
##########

best_val_loss = np.inf
best_epoch_idx = None
path = "models/mybert/"
num_labels = num_class

train_losses = {}
val_losses = {}

EPOCHS = 25
for epoch in range(EPOCHS):
    model.train()
    temp_loss = 0.0
    for batch, y in train_dl:
        preds = model(batch)
        loss = loss_fct(preds.view(-1, num_labels), y.view(-1))
        # .item() detaches the scalar; accumulating the tensor itself would
        # keep every batch's computation graph alive in memory
        temp_loss += loss.item()

        # update model (opt.zero_grad() already clears the model's gradients,
        # so an extra model.zero_grad() is redundant)
        opt.zero_grad()
        loss.backward()
        opt.step()

    temp_loss = temp_loss / len(train_dl)
    train_losses[epoch] = temp_loss
    model.eval()
    # validation
    with torch.no_grad():
        val_loss = 0.0
        for batch, y in val_dl:
            preds = model(batch)
            val_loss += loss_fct(preds.view(-1, num_labels), y.view(-1)).item()

        val_loss = val_loss / len(val_dl)
        val_losses[epoch] = val_loss
        
        if val_loss < best_val_loss:
            torch.save(model.state_dict(), path + f"ep{epoch}_loss{val_loss:.4f}.pt")
            best_val_loss = val_loss
            best_epoch_idx = epoch  # matches the variable initialised above
    print(f">>>>>>>Epoch {epoch}\tloss: {temp_loss}\tval_loss: {val_loss}")

Example output for the first 6 epochs:

>>>>>>>Epoch 0  loss: 2.0026087760925293    val_loss: 1.9571226835250854
>>>>>>>Epoch 1  loss: 2.002939224243164 val_loss: 1.9474005699157715
>>>>>>>Epoch 2  loss: 1.9835222959518433    val_loss: 2.017913579940796
>>>>>>>Epoch 3  loss: 2.0059845447540283    val_loss: 1.9957032203674316
>>>>>>>Epoch 4  loss: 1.9796814918518066    val_loss: 1.9774249792099
>>>>>>>Epoch 5  loss: 1.9920653104782104    val_loss: 2.1032602787017822

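For context on those numbers: with 9 classes, a classifier that predicts a uniform distribution has a cross-entropy of ln(9) ≈ 2.197, so the losses above mean the model is barely beating chance:

```python
import math

# Cross-entropy of a uniform prediction over 9 classes: -ln(1/9) = ln(9)
chance_loss = math.log(9)
print(round(chance_loss, 3))  # 2.197
```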
When running this loop, neither the training loss nor the validation loss decreases; both just hover around their initial values with some seemingly random fluctuation. I know that such a simple BERT classifier should be able to classify my texts well: on the same data, a ClassificationModel from the simpletransformers library (which wraps BertForSequenceClassification) reaches a validation loss of 0.32 in just 10 epochs. The behaviour persists with different optimizers and learning rates, so I must be doing something wrong either in the training loop or earlier in the pipeline.
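A standard way to vet a training loop like the one above is the overfit-a-single-batch test: a correct loop should drive the loss on one fixed, memorised batch close to zero within a few hundred steps. A self-contained sketch, with a toy linear model standing in for BERT so nothing needs to be downloaded:

```python
import torch
import torch.nn as nn

# Overfit-a-single-batch diagnostic: a healthy training loop should be able to
# drive the loss on one fixed, memorised batch close to zero.
torch.manual_seed(0)
x = torch.randn(8, 16)             # one fixed "batch": 8 examples, 16 features
y = torch.randint(0, 9, (8,))      # 9 classes, as in the question

model = nn.Linear(16, 9)
opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fct = nn.CrossEntropyLoss()

first = loss_fct(model(x), y).item()   # roughly ln(9) at initialisation
for step in range(300):
    loss = loss_fct(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()

print(first, loss.item())  # the final loss should be far below the initial one
```

If the real loop cannot overfit a single batch of the actual data either, the bug is in the loop; if it can, the problem is more likely in the data or the hyperparameters.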

Version info:

transformers 3.1.0
torch 1.8.1+cu111
multimodal-transformers 0.1.2-alpha
numpy 1.19.5
pandas 1.4.1


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
