How to freeze an entire model in PyTorch for regularization?

I am trying to train a CNN on two datasets, Dataset A and Dataset B, consecutively. First, I trained my network on Dataset A and it achieved good results. Now I want to train the same model on Dataset B, but with an extra regularization term in the loss that forces the parameters to stay close to the ones learned on Dataset A. So the loss would look something like this:

Loss(Theta) = Loss_B(Theta) + delta * sum_i (Theta_i - Theta^*_i)^2        (loss while training on Dataset B)

where the first term is the loss on Dataset B and the second term regularizes the parameters Theta to stay close to Theta^* (the parameters I obtained after training on Dataset A).
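
In PyTorch terms, the second term I have in mind is roughly the following sketch (the function name, theta_star, and delta are only illustrative; theta_star would hold detached copies of the Dataset A parameters):

import torch

def l2_to_old_params(model, theta_star, delta):
    """Weighted squared distance between the current parameters and theta_star."""
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for p, p_old in zip(model.parameters(), theta_star):
        penalty = penalty + ((p - p_old) ** 2).sum()
    return delta * penalty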

I tried to implement this with the following code:

import copy


class ParamImportance:
    def __init__(self, parameters, oldmodel, delta=1.0):
        self.parameters    = parameters
        self.oldparameters = oldmodel
        self.delta         = delta
        self.oldparaList   = []

    def OldParamList(self):
        """Save the old parameters (trained on Dataset A) in a list."""
        for oldparameter in self.oldparameters:
            oldparameter.requires_grad = False
            holder = copy.deepcopy(oldparameter)
            self.oldparaList.append(holder)  # .detach())

    def pri(self):
        """Return the squared distance between the current and the old parameters."""
        Distance_OldModel = 0
        counter = 0
        for parameter in self.parameters:
            Distance_OldModel += ((parameter - self.oldparaList[counter])**2).mean()
            counter += 1
        return Distance_OldModel

First, I loaded the model trained on Dataset A twice: one copy continues training on Dataset B, and the other is kept only for the regularization term. Then, before training, I instantiated the ParamImportance class with the parameters of both models (the setup is sketched after the training loop below). During training, I call pri from the ParamImportance class to calculate the regularization term and add it to the total loss, as in this code:

for batch_idx, (data, targets) in enumerate(loader):
    data    = data.float().to(device=DEVICE)
    targets = targets.float().unsqueeze(1).to(device=DEVICE)

    # forward
    with torch.cuda.amp.autocast():
        predictions = model(data)
        loss = loss_fn(predictions, targets)
        L2_Distance = importance.pri()
        loss = loss + L2_Distance

    # backward
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
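
For completeness, the setup that precedes this loop looks roughly like the sketch below. MyCNN is only a stand-in for my actual network, and the checkpoint path, loss function, optimizer, and delta value are illustrative:

import torch
import torch.nn as nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class MyCNN(nn.Module):
    """Stand-in for the actual CNN architecture."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(8, 1),
        )

    def forward(self, x):
        return self.net(x)

# Load the checkpoint trained on Dataset A twice: one copy keeps training
# on Dataset B, the other is used only for the regularization term.
state = torch.load("checkpoint_datasetA.pth")
model = MyCNN().to(DEVICE)
model.load_state_dict(state)
old_model = MyCNN().to(DEVICE)
old_model.load_state_dict(state)

importance = ParamImportance(model.parameters(), old_model.parameters(), delta=1.0)
importance.OldParamList()   # snapshot the Dataset A parameters once

loss_fn   = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())
scaler    = torch.cuda.amp.GradScaler()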

The issue is that the code calculates the regularization loss only for the first iteration; after that, pri just returns zero (a plain number, not a tensor). I am not sure how I should approach this problem and would appreciate any help.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow