Mistake in training loop or data preparation for BERT-based model leads to non-convergence of loss
I'm trying to train a multimodal transformer (BERT-based) for classification on both text and tabular data, but my models reduce neither training nor validation loss to any significant degree. When training these multimodal models I've noticed that they can't reach the same performance as a classifier that uses the tabular data alone. To investigate, I wrote a stripped-down version of BertForSequenceClassification and noticed that even this model (BERT plus a single linear classifier) doesn't converge on training or validation loss. From this I deduced that I must have made a mistake in my training loop. (I'm not using the transformers Trainer API because multimodal-transformers requires transformers v3.1.0, whose Trainer class has a memory leak.) My dataset and classification model are as follows:
class myDataset(torch.utils.data.Dataset):
    def __init__(self, df, num_classes):
        self.num_classes = num_classes
        self.y = np.array(df["rating_num"])  # ints, 9 labels
        self.length = len(self.y)
        self.values = df["economy"]  # texts

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        txt = self.values[idx]  # label-based pandas indexing; assumes a 0..n-1 index
        label = self.y[idx]
        return txt, label
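One subtlety I double-checked in __getitem__: `self.values[idx]` is label-based pandas indexing, so it would raise a KeyError if the DataFrame came from a shuffled split without resetting the index. A toy sketch (the column values here are made up, not my real data):

```python
import pandas as pd

# Toy DataFrame mirroring the columns used above (made-up values).
df = pd.DataFrame({"economy": ["a", "b", "c", "d"], "rating_num": [0, 1, 2, 0]})

# A shuffled split leaves a non-contiguous index, so series[idx] with
# idx = 0..n-1 can hit a missing label and raise KeyError.
train_df = df.sample(frac=0.5, random_state=0)
train_df = train_df.reset_index(drop=True)  # restore a 0..n-1 index

# After the reset, positional and label-based access agree again.
assert train_df["economy"][0] == train_df["economy"].iloc[0]
print(list(train_df.index))  # [0, 1]
```

In my case decoding the loaded batches works, so the index is contiguous and this is not the problem.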
def collate_fn(data, tokenizer=tokenizer, device=device):
    texts, labels = zip(*data)
    toks = tokenizer(list(texts), padding="max_length", truncation=True, return_tensors="pt")
    labels = torch.LongTensor(labels)
    return toks.to(device), labels.to(device)

train_dl = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE, collate_fn=collate_fn)
class myClassificationBERT(nn.Module):
    def __init__(self, bert, output_dim, dropout):
        super(myClassificationBERT, self).__init__()
        self.bert = bert
        clf_input_dim = bert.config.to_dict()["hidden_size"]
        self.classifier = nn.Linear(clf_input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        output = self.bert(**x)[1]  # pooled_output
        output = self.dropout(output)
        output = self.classifier(output)
        return output
As you can see, the model is fairly simple and is only meant to help me locate the mistake in my pipeline. I believe the dataset and dataloader classes are fine, since decoding the input_ids of the loaded batches produces the expected texts. Model init and training loop follow:
# model init
bert = BertModel.from_pretrained("bert-base-uncased")
model = myClassificationBERT(bert=bert, output_dim=num_class, dropout=.2).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fct = nn.CrossEntropyLoss()
##########
#training
##########
best_val_loss = np.inf
best_epoch_idx = None
path = "models/mybert/"
num_labels = num_class
train_losses = {}
val_losses = {}
EPOCHS = 25

for epoch in range(EPOCHS):
    model.train()
    temp_loss = 0
    for batch, y in train_dl:
        preds = model(batch)
        loss = loss_fct(preds.view(-1, num_labels), y.view(-1))
        temp_loss += loss.item()  # .item() detaches, so the graph isn't kept alive
        # update model
        loss.backward()
        opt.step()
        opt.zero_grad()
    temp_loss = temp_loss / len(train_dl)
    train_losses[epoch] = temp_loss

    # validation
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for batch, y in val_dl:
            preds = model(batch)
            val_loss += loss_fct(preds.view(-1, num_labels), y.view(-1)).item()
        val_loss = val_loss / len(val_dl)
        val_losses[epoch] = val_loss

    if val_loss < best_val_loss:
        torch.save(model.state_dict(), path + f"ep{epoch}_loss{val_loss}.pt")
        best_val_loss = val_loss
        best_epoch_idx = epoch
    print(f">>>>>>>Epoch {epoch}\tloss: {temp_loss}\tval_loss: {val_loss}")
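To narrow down whether the backward pass delivers any learning signal at all, I also inspected the total gradient norm after loss.backward(). A minimal sketch of that check, using a tiny stand-in model (the nn.Sequential here is hypothetical; the same lines work on the BERT classifier inside the loop above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the classifier; any nn.Module behaves the same.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 9))
x = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 3])

loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()

# Total gradient norm over all parameters; a value at (or extremely
# near) zero would mean no signal reaches the weights.
grad_norm = torch.norm(
    torch.stack([p.grad.norm() for p in model.parameters() if p.grad is not None])
).item()
print(grad_norm > 0)  # True for a healthy backward pass
```

In my runs the gradients are nonzero, so the graph itself seems intact.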
Example output for the first six epochs:
>>>>>>>Epoch 0 loss: 2.0026087760925293 val_loss: 1.9571226835250854
>>>>>>>Epoch 1 loss: 2.002939224243164 val_loss: 1.9474005699157715
>>>>>>>Epoch 2 loss: 1.9835222959518433 val_loss: 2.017913579940796
>>>>>>>Epoch 3 loss: 2.0059845447540283 val_loss: 1.9957032203674316
>>>>>>>Epoch 4 loss: 1.9796814918518066 val_loss: 1.9774249792099
>>>>>>>Epoch 5 loss: 1.9920653104782104 val_loss: 2.1032602787017822
When running this loop, neither the training loss nor the validation loss decreases; both hover around their initial values with some (seemingly random) fluctuation. I know that such a simple BERT classifier should be able to classify my texts well: I've tested the same data with a ClassificationModel from the simpletransformers library, which wraps a BertForSequenceClassification model and reaches a validation loss of 0.32 in just 10 epochs. The same behavior occurs with different optimizers and learning rates, so I must be doing something wrong either in the training loop or earlier in the pipeline.
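For context on the numbers above: with 9 classes, a model that never moves past uniform random guessing sits at a cross-entropy of ln(9) ≈ 2.197, which is close to where my losses hover. A quick check:

```python
import math
import torch
import torch.nn as nn

# Uniform logits give a uniform softmax, i.e. chance-level predictions
# for a 9-class problem; the resulting cross-entropy is exactly ln(9).
logits = torch.zeros(4, 9)
labels = torch.tensor([0, 3, 5, 8])
chance_loss = nn.CrossEntropyLoss()(logits, labels).item()

print(round(chance_loss, 3), round(math.log(9), 3))  # 2.197 2.197
```

So the model appears to be stuck at chance level rather than slowly improving.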
Version info:
transformers 3.1.0
torch 1.8.1+cu111
multimodal-transformers 0.1.2-alpha
numpy 1.19.5
pandas 1.4.1
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
