U-Net PyTorch model outputting NaN for MSE but not L1?

I’m training a U-Net (model below) on 10 different datasets. When I train with L1 loss I get no errors, but for one of the datasets, training with MSE loss makes the network output tensor([nan, …]) and it cannot train beyond roughly the 5th-10th epoch. I have tried gradient clipping with arbitrary max-norm values of 0.5 and 1; neither worked. I've also checked that none of the inputs contain NaN values.

Any help would be appreciated. Thank you.
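For reference, the clipping I tried looked roughly like this (a minimal sketch of my training step; model, criterion, optimizer, and loader are placeholders for my actual objects):

import torch

for inputs, targets in loader:
    # Verify the batch itself is clean before the forward pass.
    assert not torch.isnan(inputs).any()

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)  # nn.MSELoss() or nn.L1Loss()
    loss.backward()

    # Clip the global gradient norm; I tried max_norm=0.5 and 1.0.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()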


import torch.nn as nn

import config  # project config module; assumed to expose numinputs (input channel count)


class CNN(nn.Module):
    # Encoder-decoder CNN (U-Net-style, but without skip connections).
    def __init__(self):
        super(CNN, self).__init__()
        # Encoder, stage 1: two 3x3 conv blocks at 64 channels.
        self.layer1 = nn.Sequential(
            nn.Conv2d(config.numinputs, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        # 2x downsampling.
        self.layer3 = nn.MaxPool2d(2, stride=2, padding=0)

        # Encoder, stage 2: two conv blocks at 128 channels.
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer5 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer6 = nn.MaxPool2d(2, stride=2, padding=0)

        # Bottleneck: three conv blocks at 256 channels.
        self.layer7 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256))

        self.layer8 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256))

        self.layer9 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(256))

        # Decoder, stage 1: 2x bilinear upsampling (deprecated alias of
        # nn.Upsample(scale_factor=2, mode='bilinear')), then 128-channel convs.
        self.layer10 = nn.UpsamplingBilinear2d(scale_factor=2)

        self.layer11 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer12 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        self.layer13 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(128))

        # Decoder, stage 2: 2x upsampling, then 64-channel convs.
        self.layer14 = nn.UpsamplingBilinear2d(scale_factor=2)

        self.layer15 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        self.layer16 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(64))

        # Output head: project down to a single channel.
        self.layer17 = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.ReLU(),
            nn.BatchNorm2d(1))

        # Softmax over the flattened image, so all pixels sum to 1.
        self.layer18 = nn.Softmax(dim=1)

    def forward(self, x):
        # Encoder
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        # Bottleneck
        x = self.layer7(x)
        x = self.layer8(x)
        x = self.layer9(x)
        # Decoder
        x = self.layer10(x)
        x = self.layer11(x)
        x = self.layer12(x)
        x = self.layer13(x)
        x = self.layer14(x)
        x = self.layer15(x)
        x = self.layer16(x)
        x = self.layer17(x)
        # Flatten to (1, H*W), apply softmax across every pixel, then restore
        # the spatial layout; this hard-codes batch size 1 and 512x512 inputs.
        x = x.view(1, -1)
        x = self.layer18(x)
        x = x.reshape(1, 512, 512)

        return x
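For completeness, a minimal way to reproduce the shapes and check the output for NaNs (this assumes config.numinputs is set, and uses the 512×512, batch-size-1 shape that the forward pass hard-codes):

import torch

model = CNN()
dummy = torch.randn(1, config.numinputs, 512, 512)
out = model(dummy)
print(out.shape)                      # torch.Size([1, 512, 512])
print(torch.isnan(out).any().item())  # True once the network has collapsed to NaN
print(out.sum().item())               # softmax over all pixels, so this is ~1.0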



Solution 1 [1]:

To my knowledge, neither MSE nor L1 loss is well suited to training a U-Net, mainly because this model performs semantic segmentation and outputs a probability distribution.

Loss functions such as categorical/binary cross-entropy, or focal loss if you have class imbalance, will definitely give you better results!
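For example, for a binary mask you could drop the final flatten/softmax from the model and train the raw single-channel output with BCEWithLogitsLoss (a minimal sketch, not the asker's exact setup; model, inputs, and targets are placeholders, and targets is assumed to be a float 0/1 mask with the same shape as the output):

import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()          # expects raw logits; applies sigmoid internally

logits = model(inputs)                      # (N, 1, H, W) once the flatten/softmax is removed
loss = criterion(logits, targets.float())   # targets: 0/1 mask, same shape as logits
loss.backward()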


Solution Source

[1] Solution 1: FrenchPerceptron (Stack Overflow, CC BY-SA 3.0)