CUDA memory issues in PyTorch

I'm trying to fine-tune BERT for a classification task in Google Colab, and I often hit the CUDA out-of-memory error. Two things about it seem strange:

  1. If I run my training loop with a batch size of, say, 20, it may initially run fine for a while. But if I then stop the code and run it again, I get CUDA out-of-memory errors, even after reducing the batch size to 1. Only after restarting the runtime (or a factory reset) does it run fine again. What causes memory usage to grow across runs? I don't think I'm storing anything between runs, so the only thing that should matter is the batch size.
  2. My Colab session also often crashes due to (CPU) RAM exhaustion after multiple runs, but works fine for a while after restarting the runtime.
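In case it's relevant: between runs I've tried explicitly dropping references and clearing PyTorch's caching allocator, along the lines of the sketch below (this only uses the standard `gc` and `torch.cuda` APIs; whether it actually frees everything in a notebook is exactly what I'm unsure about):

```python
import gc
import torch

def release_cuda_memory():
    # Collect unreachable Python objects first, so the tensors they
    # reference are actually freed before we flush the allocator.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached (but unused) blocks to the driver; this does NOT
        # free tensors that are still referenced somewhere.
        torch.cuda.empty_cache()

# usage between runs in a notebook cell:
# del model, optimizer
# release_cuda_memory()
```

My understanding is that in a notebook the old model can stay alive via cell output history or a saved exception traceback even after `del`, which might explain why only a runtime restart reliably helps.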

Here is the code I'm using for training. The class Model is my model: pre-trained BERT (from the Hugging Face transformers library) with a few layers on top.
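For reference, Model looks roughly like this (a simplified sketch; the checkpoint name and the classifier head are placeholders, the real model has a few more layers):

```python
import torch.nn as nn
from transformers import BertModel

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Pre-trained encoder from Hugging Face transformers.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Small head on top; Sigmoid keeps the output in (0, 1) for BCELoss.
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(out.pooler_output).squeeze(-1)
```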

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.BCELoss().to(device)

batch_size = 32
train_data_loader = DataLoader(data, batch_size=batch_size)  # data is a subclass of torch.utils.data.Dataset

def train(n_epochs):
    model.train()
    for epoch in range(n_epochs):
        total_loss = 0
        for embeddings, labels in train_data_loader:
            labels = labels.double().to(device)
            input_ids = embeddings['input_ids'].squeeze(1).to(device)
            attention_mask = embeddings['attention_mask'].squeeze(1).to(device)
            output = model(input_ids, attention_mask)
            loss = criterion(output, labels)
            total_loss += loss.item()  # .item() detaches, so the graph isn't kept alive
            model.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'epoch {epoch+1}, total loss = {total_loss}')

total_epochs = 10
train(total_epochs)
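For debugging, I've been printing GPU memory stats between runs with a small helper like this (only uses the standard torch.cuda counters; on CPU-only runtimes it just reports zeros):

```python
import torch

def report_gpu_memory(tag=""):
    """Print and return (allocated, reserved) GPU memory in MiB."""
    if not torch.cuda.is_available():
        print(f"{tag}: CUDA not available")
        return 0.0, 0.0
    # memory_allocated: bytes held by live tensors.
    # memory_reserved: bytes PyTorch's caching allocator has claimed
    # from the driver (>= allocated; freed only by empty_cache()).
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")
    return alloc, reserved
```

Calling it before and after a run shows allocated memory creeping up across runs, which matches the behavior in point 1.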

Is there anything wrong or inefficient in my training loop that could be causing these errors?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
