How to freeze an entire model in PyTorch for regularization?
I am trying to train a CNN on two different datasets, Dataset A and Dataset B, sequentially. First, I trained my network on Dataset A and it achieved good results. Now I want to train the same model on Dataset B, but for B I want to add a regularization term to the loss that forces the model to stay close to the parameters it learned on Dataset A. So something similar to this loss:
Loss while training on Dataset B:

Loss(θ) = Loss_B(θ) + λ · Σ_i (θ_i − θ_i*)²

where the first term in the equation is the loss on Dataset B and the second term regularizes the parameters θ to stay close to θ* (which I got after training on Dataset A), with λ controlling how strongly the model is pulled back toward θ*.
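Written as a standalone sketch, the extra term would look roughly like this (lambda_reg, l2_to_reference, and old_model are illustrative names, not part of my actual training code):

    import torch

    lambda_reg = 0.1  # illustrative regularization strength

    def l2_to_reference(params, ref_params):
        # Sum of squared differences between the current parameters
        # and the frozen Dataset A parameters.
        penalty = 0.0
        for p, p_ref in zip(params, ref_params):
            penalty = penalty + ((p - p_ref.detach()) ** 2).sum()
        return penalty

    # total_loss = loss_on_B + lambda_reg * l2_to_reference(model.parameters(), old_model.parameters())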
I tried to implement this with the following class:
import copy

class ParamImportance:
    def __init__(self, parameters, oldmodel, delta=1.0):
        self.parameters = parameters      # parameters of the model being trained on Dataset B
        self.oldparameters = oldmodel     # parameters of the model trained on Dataset A
        self.delta = delta                # regularization weight (not used in pri yet)
        self.oldparaList = []

    def OldParamList(self):
        """ Save the old parameters (Trained on Dataset A) in a list """
        for oldparameter in self.oldparameters:
            oldparameter.requires_grad = False
            holder = copy.deepcopy(oldparameter)
            self.oldparaList.append(holder)  # .detach())

    def pri(self):
        """ Squared distance between the current parameters and the Dataset A parameters """
        Distance_OldModel = 0
        counter = 0
        for parameter in self.parameters:
            Distance_OldModel += ((parameter - self.oldparaList[counter]) ** 2).mean()
            counter += 1
        return Distance_OldModel
First, I loaded the model I trained on Dataset A twice: one copy to continue training on Dataset B, and the other to serve as the frozen reference for regularization (a rough sketch of that setup is shown after the training loop below). Then, before training, I instantiated the ParamImportance class with the parameters of both models and called OldParamList to store the Dataset A parameters. During training, I called pri from the ParamImportance class to calculate the regularization term and added it to the total loss, as in this code:
for batch_idx, (data, targets) in enumerate(loader):
    data = data.float().to(device=DEVICE)
    targets = targets.float().unsqueeze(1).to(device=DEVICE)

    # forward pass under mixed precision
    with torch.cuda.amp.autocast():
        predictions = model(data)
        loss = loss_fn(predictions, targets)
        L2_Distance = importance.pri()   # regularization term toward the Dataset A parameters
        loss = loss + L2_Distance

    # backward pass and optimizer step with gradient scaling
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
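For reference, the setup before this loop looks roughly like the following; the checkpoint path and the old_model name are placeholders rather than my exact code, while model and DEVICE are the same objects used in the loop above:

    import copy
    import torch

    # Load the checkpoint from Dataset A twice: one copy continues training on Dataset B,
    # the other stays frozen and is only used as the reference for the regularization term.
    checkpoint = torch.load("model_dataset_A.pth", map_location=DEVICE)  # placeholder path
    model.load_state_dict(checkpoint)
    old_model = copy.deepcopy(model).to(DEVICE)
    old_model.eval()

    importance = ParamImportance(model.parameters(), old_model.parameters())
    importance.OldParamList()  # snapshot the Dataset A parameters into oldparaList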
The issue is that the code computes the regularization loss only for the first iteration; after that, pri just returns zero (a plain 0, not a tensor). I am not sure how I should approach this problem and would appreciate any help.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow