Loss implementation for combining two models

I built a time-series regression model. I use a ConvAutoencoder as a denoising model, and SCINet (the main regression model) uses the output of the ConvAE as its target, as in the figure below.

[Figure: architecture diagram — the ConvAE denoises the target series, and its output serves as the regression target for SCINet]
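
The exact ConvAutoEncoder definition is omitted for brevity; roughly, it is a 1-D convolutional autoencoder along these lines (the layer sizes here are simplified placeholders, not my real model):

import torch

class ConvAutoEncoder(torch.nn.Module):
    # Illustrative 1-D conv autoencoder; the real layer sizes differ.
    def __init__(self, channels=1, hidden=16):
        super().__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv1d(channels, hidden, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden, hidden, kernel_size=3, padding=1),
            torch.nn.ReLU(),
        )
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv1d(hidden, hidden, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv1d(hidden, channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Reconstruct the (denoised) input series.
        return self.decoder(self.encoder(x))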

I implemented the training loop as follows:

import torch

# SCINet, ConvAutoEncoder, the hyperparameters, device, and the
# data loaders are defined elsewhere in the script.
sci_model = SCINet(output_len=output_len, input_len=input_len, input_dim=input_dim,
                   hid_size=hid_size, num_stacks=num_stacks, num_levels=num_levels,
                   concat_len=concat_len, groups=groups, kernel=kernel, dropout=dropout,
                   single_step_output_One=0, positionalE=True, modified=True).to(device)
ae_model = ConvAutoEncoder().to(device)

loss_function = torch.nn.MSELoss()
sci_optimizer = torch.optim.AdamW(sci_model.parameters(), lr=lr)
ae_optimizer = torch.optim.AdamW(ae_model.parameters(), lr=lr)

# One optimizer over the parameters of both models, so a single
# backward()/step() updates the AE and SCINet jointly.
all_params = list(ae_model.parameters()) + list(sci_model.parameters())
total_optimizer = torch.optim.AdamW(all_params, lr=lr)

loss_list = []
ae_loss_list = []
sci_loss_list = []
EPOCHS = 50
max_norm = 10   # gradient-clipping threshold
alpha = 10      # weight of the AE reconstruction loss in the combined loss

for epoch in range(1, EPOCHS + 1):
    ae_model.train()
    sci_model.train()
    total_loss = 0.0
    n = 0

    for train_X, train_y in train_loader:
        # Reshape the target series to (batch, channel, length) and repeat it
        # along the batch dimension; 128 = sequence_length.
        input_data = torch.reshape(torch.transpose(train_y, 0, 1),
                                   (1, 1, train_y.shape[0])).repeat(128, 1, 1).to(device)
        pred = ae_model(input_data)

        sci_input_data = train_X.to(device)
        # The denoised AE output becomes SCINet's target. Note that this
        # target is not detached, so gradients from sci_loss flow back into
        # the autoencoder as well.
        sci_target = torch.transpose(pred[0], 0, 1)
        sci_pred = sci_model(sci_input_data)

        ae_loss = loss_function(pred, input_data)
        sci_loss = loss_function(sci_pred, sci_target)

        # Way 1 (active): one combined loss, one optimizer.
        total_optimizer.zero_grad()
        multi_loss = alpha * ae_loss + sci_loss * sci_input_data.size(0)
        # Alternative weighting that was also tried:
        # multi_loss = alpha * ae_loss * input_data.size(0) + sci_loss * sci_input_data.size(0)
        multi_loss.backward()
        torch.nn.utils.clip_grad_norm_(all_params, max_norm)
        total_optimizer.step()

        # Way 2 (commented out): separate backward passes and optimizers.
        # sci_optimizer.zero_grad()
        # ae_optimizer.zero_grad()
        # ae_loss.backward(retain_graph=True)
        # sci_loss.backward()
        # torch.nn.utils.clip_grad_norm_(ae_model.parameters(), max_norm)
        # torch.nn.utils.clip_grad_norm_(sci_model.parameters(), max_norm)
        # ae_optimizer.step()
        # sci_optimizer.step()

        total_loss += multi_loss.item()
        n += sci_input_data.size(0)

    # Validation
    ae_model.eval()
    sci_model.eval()
    ae_total_loss, sci_total_loss, val_total_loss = 0.0, 0.0, 0.0
    val_n = 0

    torch.cuda.empty_cache()

    with torch.no_grad():
        for valid_X, valid_y in valid_loader:
            input_data = torch.reshape(torch.transpose(valid_y, 0, 1),
                                       (1, 1, valid_y.shape[0])).repeat(128, 1, 1).to(device)
            pred = ae_model(input_data)
            ae_loss = loss_function(pred, input_data)

            sci_input_data = valid_X.to(device)
            sci_target = torch.transpose(pred[0], 0, 1)
            sci_pred = sci_model(sci_input_data)
            sci_loss = loss_function(sci_pred, sci_target)

            # Accumulate the per-batch losses with the same weighting as in
            # training, detached to plain floats via .item().
            batch_size = sci_input_data.size(0)
            ae_total_loss += ae_loss.item()
            sci_total_loss += sci_loss.item() * batch_size
            val_total_loss += alpha * ae_loss.item() + sci_loss.item() * batch_size
            val_n += batch_size

    loss_list.append(val_total_loss / val_n)
    ae_loss_list.append(ae_total_loss / val_n)
    sci_loss_list.append(sci_total_loss / val_n)

    # Save the SCINet weights whenever the validation loss improves.
    if epoch == 1 or loss_list[-1] < min_loss:
        min_loss = loss_list[-1]
        torch.save(sci_model.state_dict(), './models/SCINet')

As you can see in the code, I tried two ways of constructing the losses. One way is a single combined loss, total_loss = loss1 (AE loss) + loss2 (SCINet loss), backpropagated once through one optimizer over all parameters. The other way, now commented out in the code block above, calls ae_loss.backward(retain_graph=True) followed by sci_loss.backward() and steps a separate optimizer for each model.
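To make the two patterns concrete, here is a minimal, self-contained sketch with toy stand-in models (the nn.Linear modules, sizes, and random data are placeholders for illustration, not my actual ConvAE/SCINet):

import torch

model_a = torch.nn.Linear(4, 4)   # stand-in for the autoencoder
model_b = torch.nn.Linear(4, 1)   # stand-in for SCINet
mse = torch.nn.MSELoss()
x = torch.randn(8, 4)

# Way 1: one combined loss, one backward, one optimizer step.
opt_all = torch.optim.AdamW(list(model_a.parameters()) + list(model_b.parameters()), lr=1e-3)
recon = model_a(x)
loss_a = mse(recon, x)
loss_b = mse(model_b(x), recon)   # model_a's output is model_b's target
opt_all.zero_grad()
(10 * loss_a + loss_b).backward()
opt_all.step()

# Way 2: two backward passes; retain_graph=True keeps the shared graph
# alive for the second backward.
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)
recon = model_a(x)
loss_a = mse(recon, x)
loss_b = mse(model_b(x), recon)
opt_a.zero_grad()
opt_b.zero_grad()
loss_a.backward(retain_graph=True)
loss_b.backward()
opt_a.step()
opt_b.step()

Up to the loss weights, the two variants produce the same gradients on model_a, because the gradients of the two backward calls simply accumulate.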

However, the autoencoder does not train well in either way: when I predict with the final models, it just outputs 0.
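
I also wondered whether I should detach the AE output when building sci_target, so that sci_loss does not backpropagate into the autoencoder (otherwise the joint objective might be reduced by pushing the AE output toward a trivial constant that SCINet matches easily). A one-line sketch of that change inside the training loop above, though I am not sure whether it is the right fix:

# Hypothetical check: detach the AE output before using it as SCINet's
# target, so the autoencoder is trained only by its reconstruction loss.
sci_target = torch.transpose(pred[0], 0, 1).detach()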

Can someone help me train the autoencoder and SCINet properly?

P.S. Sorry for my poor English.


