Memory consumption accumulating in a while loop
I am new to Python and am trying to run a while loop until a condition is reached. The code runs fine locally, but when I submit it to an HPC cluster, it runs out of memory, even though I overwrite the same variables on each iteration.
I have tried calling the garbage collector, but it does not seem to help. Any ideas on how to deal with this issue?
while loss > epsilon:
    optimizer.zero_grad()
    pred_sai = net(x_data)  # count * 50 : count * 50 + 50
    y_pred_sai = net(y_data)
    pred_sai = torch.cat([pred_sai, fixed_sai], dim=1)
    y_pred_sai = torch.cat([y_pred_sai, y_fixed_sai], dim=1)
    pred_sai_T = torch.transpose(pred_sai, 0, 1)
    K_tilde = torch.mm(
        torch.pinverse(inv_N * torch.mm(pred_sai_T, pred_sai) + lambda_ * I),
        inv_N * torch.mm(pred_sai_T, y_pred_sai),
    )
    y_pred_sai_T = torch.transpose(y_pred_sai, 0, 1)
    MSE = (y_pred_sai_T - torch.mm(K_tilde, pred_sai_T)) ** 2
    loss = torch.sum(MSE)
    loss.backward()
    optimizer.step()
    count += 1
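
The snippet above does not show enough to say for certain where the memory goes, so what follows is only a sketch of one common mitigation, not a confirmed fix. In PyTorch loops, a loss tensor that stays alive outside the iteration keeps its autograd graph referenced. One way to rule that out is to run each step inside a function and let only a plain Python float escape. The names `train_until`, `step`, `max_iters`, and `fake_step` are hypothetical; in the real code, `step` would wrap the body of the loop above and end with `return loss.item()`. The stand-in step here avoids PyTorch so the example runs on its own.

```python
from typing import Callable, Tuple


def train_until(
    step: Callable[[], float], epsilon: float, max_iters: int = 10_000
) -> Tuple[float, int]:
    """Call `step` until its returned loss is <= epsilon or max_iters is hit.

    `step` is a hypothetical callable that runs ONE optimization step and
    returns the loss as a plain Python float. In PyTorch it would end with
    `return loss.item()`, so no tensor (and no autograd graph) escapes the
    call; every intermediate tensor is a local that is released on return.
    """
    loss_value = float("inf")
    iters = 0
    # The loop condition compares floats only, never tensors.
    while loss_value > epsilon and iters < max_iters:
        loss_value = step()
        iters += 1
    return loss_value, iters


# Stand-in step for demonstration only: halves a fake loss on every call.
state = {"loss": 8.0}


def fake_step() -> float:
    state["loss"] /= 2.0
    return state["loss"]


final_loss, n_iters = train_until(fake_step, epsilon=0.1)
print(final_loss, n_iters)
```

The `max_iters` cap is an added safety assumption so that a loss that never drops below `epsilon` cannot keep a cluster job running indefinitely. If memory still grows with this structure, the cause is more likely outside the loop body shown, for example a list that collects per-iteration tensors.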
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
