Train a model to output the weights of another model, and use the second model purely for function evaluation

I have two models, A and B: A(x1) = weights of B, and B(x2) = final output.

A is trainable; B is not trainable (I just want to load the outputs of A into B and run inference).

Problem I am facing: the output of A is a torch.Tensor. To set the weights of B, I had to slice the output tensor of A and wrap each slice in nn.Parameter. However, I am losing the gradient flow from the final loss back to the weights of A, so no training happens. How do I implement the idea or correct my code?
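For context on why the flow breaks: nn.Parameter(...) always creates a new leaf tensor with no grad_fn, so whatever graph connected the wrapped tensor to the hypernetwork is cut at that point. A minimal illustration (the names here are only for the demo):

import torch
import torch.nn as nn

w = torch.randn(4, requires_grad=True)
y = w * 2.0                 # y is connected to w through the autograd graph
p = nn.Parameter(y)         # p is a brand-new leaf tensor; the link to w is gone
p.sum().backward()
print(p.grad)               # gradients stop here, at the new leaf
print(w.grad)               # None: nothing reaches the original tensor

The full code that reproduces the problem: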

import torch
import torch.nn as nn
import torch.nn.functional as F

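# Model A: hypernetwork that maps a scalar input to the 177 flattened weights and biases of Main_Model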
class Hyper_Model(nn.Module):

    def __init__(self):

        super(Hyper_Model, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 177),
        )

  
    def forward(self,param):        
        param_ = self.layers(param)            
        return param_

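# Model B: target network whose weights and biases are overwritten from the hypernetwork output in forward()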
class Main_Model(nn.Module):

    def __init__(self):

        super(Main_Model, self).__init__()
        self.linear1 = nn.Linear(2,8)
        self.linear2 = nn.Linear(8,8)
        self.linear3 = nn.Linear(8,8)
        self.out = nn.Linear(8,1)

    def forward(self,param_,x):
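        # Wrapping slices of param_ in nn.Parameter creates new leaf tensors,
        # which cuts the autograd graph back to Hyper_Model (and requires_grad=False
        # below additionally blocks gradients) -- this is where training breaks.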
        self.linear1.weight = torch.nn.Parameter(param_[0,:16].view(8,2))
        self.linear2.weight = torch.nn.Parameter(param_[0,24:88].view(8,8))
        self.linear3.weight = torch.nn.Parameter(param_[0,96:160].view(8,8))
        self.linear1.bias = torch.nn.Parameter(param_[0,16:24].view(8))
        self.linear2.bias = torch.nn.Parameter(param_[0,88:96].view(8))
        self.linear3.bias = torch.nn.Parameter(param_[0,160:168].view(8))
        self.out.weight = torch.nn.Parameter(param_[0,168:176].view(1,8))
        self.out.bias = torch.nn.Parameter(param_[0,176:].view(1))

        self.linear1.weight.requires_grad = False
        self.linear2.weight.requires_grad = False
        self.linear3.weight.requires_grad = False        
        self.linear1.bias.requires_grad = False
        self.linear2.bias.requires_grad = False
        self.linear3.bias.requires_grad = False
        self.out.weight.requires_grad =  False
        self.out.bias.requires_grad =  False

        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x)) 
        x = self.out(x)
        return x

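# Toy data: X holds (x, t) input pairs, Y the targets; param is the scalar fed to the hypernetwork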
x = torch.tensor([1.0,2.0,3.0],requires_grad=True).view(3,1)
t = torch.tensor([1.0,1.5,2.0],requires_grad=True).view(3,1)
param = torch.tensor([-0.01]).view(1,1)
X = torch.cat([x,t],dim=1)
Y = torch.tensor([5.0,6.0,9.0]).view(3,1)
h = Hyper_Model()
m = Main_Model()
opt = torch.optim.Adam(list(h.parameters()), lr=0.001)
loss_func = nn.MSELoss()

for i in range(10):
    opt.zero_grad()
    param_ = h(param)   

    out = m(param_,X)
    loss = loss_func(out,Y)

    print(i,loss)

    loss.backward()
    opt.step()

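One way to fix this, sketched below, is to make Main_Model stateless and apply the sliced hypernetwork outputs directly through the functional API (torch.nn.functional.linear) instead of wrapping them in nn.Parameter. The slicing indices are the same as above; only the forward pass changes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Main_Model(nn.Module):
    # Stateless version: every weight and bias comes in through param_,
    # so the autograd graph from the loss back to Hyper_Model stays intact.
    def forward(self, param_, x):
        x = F.relu(F.linear(x, param_[0, :16].view(8, 2), param_[0, 16:24]))
        x = F.relu(F.linear(x, param_[0, 24:88].view(8, 8), param_[0, 88:96]))
        x = F.relu(F.linear(x, param_[0, 96:160].view(8, 8), param_[0, 160:168]))
        x = F.linear(x, param_[0, 168:176].view(1, 8), param_[0, 176:])
        return x

With this change the training loop stays exactly as written: loss.backward() now propagates through param_ into Hyper_Model, and opt.step() updates its weights. Main_Model registers no parameters of its own, so there is nothing to freeze.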

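Alternatively, if you would rather keep Main_Model written with ordinary nn.Linear layers (and a forward that takes only x), newer PyTorch releases provide torch.func.functional_call (earlier exposed as torch.nn.utils.stateless.functional_call), which runs a module with parameters supplied as a dictionary instead of the registered ones. A sketch under that assumption:

from torch.func import functional_call

# Assumes m is a Main_Model variant whose forward only takes x,
# and param_ is the (1, 177) output of the hypernetwork.
weights = {
    "linear1.weight": param_[0, :16].view(8, 2),
    "linear1.bias": param_[0, 16:24],
    "linear2.weight": param_[0, 24:88].view(8, 8),
    "linear2.bias": param_[0, 88:96],
    "linear3.weight": param_[0, 96:160].view(8, 8),
    "linear3.bias": param_[0, 160:168],
    "out.weight": param_[0, 168:176].view(1, 8),
    "out.bias": param_[0, 176:],
}
out = functional_call(m, weights, (X,))   # gradients still reach Hyper_Model
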
Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
