Memory consumption accumulating in a while loop

I am new to Python, and I am trying to run a while loop until a condition is reached. The code runs fine locally, but when I submit it to an HPC, it runs out of memory. I'm overwriting the same variables on each iteration.

I have tried using garbage collection, but it does not seem to work. Any ideas about how to deal with this issue?
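CPython frees an object as soon as its last reference disappears, and `gc.collect()` only reclaims objects that are already unreachable (such as reference cycles). Overwriting a variable therefore releases the old value only if nothing else still points to it. The stdlib sketch below illustrates this with `tracemalloc`. The `bytearray` buffers are a hypothetical stand-in for one iteration's data, the function name `retained_bytes` is made up for the example, and `tracemalloc` only sees Python-heap allocations, not GPU memory held by PyTorch.

```python
import gc
import tracemalloc


def retained_bytes(keep_history: bool, iters: int = 200) -> int:
    """Return how many traced bytes are still allocated after the loop ends."""
    history = []
    tracemalloc.start()
    start, _ = tracemalloc.get_traced_memory()
    result = None
    for _ in range(iters):
        # Overwriting `result` drops the previous buffer's last reference...
        result = bytearray(10_000)
        if keep_history:
            # ...unless another reference (here, the list) keeps it alive.
            history.append(result)
    gc.collect()  # cannot free buffers that are still reachable via `history`
    current, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return current - start


print("overwrite only:   ", retained_bytes(keep_history=False))
print("keep a reference: ", retained_bytes(keep_history=True))
```

With only overwriting, roughly one buffer survives; with the extra reference, all 200 do, and `gc.collect()` makes no difference.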

while loss > epsilon:
    optimizer.zero_grad()

    # Forward pass on both inputs, then append the fixed (non-learned) features
    pred_sai = net(x_data)
    y_pred_sai = net(y_data)
    pred_sai = torch.cat([pred_sai, fixed_sai], dim=1)
    y_pred_sai = torch.cat([y_pred_sai, y_fixed_sai], dim=1)

    # Regularized least-squares estimate of K_tilde
    pred_sai_T = torch.transpose(pred_sai, 0, 1)
    K_tilde = torch.mm(
        torch.pinverse(inv_N * torch.mm(pred_sai_T, pred_sai) + lambda_ * I),
        inv_N * torch.mm(pred_sai_T, y_pred_sai),
    )

    # Element-wise squared residuals, summed into a scalar loss
    y_pred_sai_T = torch.transpose(y_pred_sai, 0, 1)
    MSE = (y_pred_sai_T - torch.mm(K_tilde, pred_sai_T)) ** 2
    loss = torch.sum(MSE)

    loss.backward()
    optimizer.step()

    count += 1
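One way to check whether memory really grows from one iteration to the next is to instrument the loop. The sketch below is a self-contained toy version of the loop above; the network architecture, tensor sizes, optimizer, learning rate, and the random stand-ins for `x_data`, `y_data`, `fixed_sai` and `y_fixed_sai` are all assumptions, since the question does not show them. It compares a plain Python float (`loss.item()`) against `epsilon`, caps the number of iterations, and, on a CUDA device, records `torch.cuda.memory_allocated()` after each step. It is a diagnostic sketch, not a confirmed fix for the original problem.

```python
import torch
import torch.nn as nn

# Hypothetical sizes and random stand-in data (not from the question).
N, D_IN, D_NET, D_FIXED = 64, 8, 4, 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)

net = nn.Sequential(nn.Linear(D_IN, 16), nn.Tanh(), nn.Linear(16, D_NET)).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
x_data = torch.randn(N, D_IN, device=device)
y_data = torch.randn(N, D_IN, device=device)
fixed_sai = torch.randn(N, D_FIXED, device=device)  # assumed constant, no grad
y_fixed_sai = torch.randn(N, D_FIXED, device=device)
inv_N = 1.0 / N
lambda_ = 1e-3
I = torch.eye(D_NET + D_FIXED, device=device)


def train(epsilon=1e-3, max_iters=100):
    """Run the loop, returning per-iteration losses and (CUDA only) memory readings."""
    losses, mem_bytes = [], []
    loss_value = float("inf")
    count = 0
    # The iteration cap guards against a loss that never drops below epsilon.
    while loss_value > epsilon and count < max_iters:
        optimizer.zero_grad()
        pred_sai = torch.cat([net(x_data), fixed_sai], dim=1)
        y_pred_sai = torch.cat([net(y_data), y_fixed_sai], dim=1)
        pred_sai_T = pred_sai.T
        K_tilde = torch.pinverse(inv_N * pred_sai_T @ pred_sai + lambda_ * I) @ (
            inv_N * pred_sai_T @ y_pred_sai
        )
        loss = ((y_pred_sai.T - K_tilde @ pred_sai_T) ** 2).sum()
        loss.backward()
        optimizer.step()

        # Keep only a Python float across iterations; holding on to `loss`
        # tensors (e.g. `total += loss`) keeps references to their autograd history.
        loss_value = loss.item()
        losses.append(loss_value)
        if device.type == "cuda":
            mem_bytes.append(torch.cuda.memory_allocated(device))
        count += 1
    return losses, mem_bytes


losses, mem_bytes = train(max_iters=20)
print(f"iterations: {len(losses)}, last loss: {losses[-1]:.4f}")
if mem_bytes:
    print(f"allocated bytes, first vs last: {mem_bytes[0]} vs {mem_bytes[-1]}")
```

If the recorded allocation keeps rising with the real data, something is holding references across iterations. If it stays flat, a single iteration may simply need more memory than the HPC job was given.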



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
