PyTorch: Multi-class segmentation loss value != 0 when using target image as the prediction

I am performing semantic segmentation using PyTorch. The dataset has 103 classes in total, and the targets are RGB images in which only the red channel contains the labels. I am using nn.CrossEntropyLoss as my loss function. As a sanity check, I wanted to verify that nn.CrossEntropyLoss is the right choice for this problem and that it behaves as expected.

I pick a random mask from my dataset and create a categorical (one-hot) version of it using this custom transform:

import torch


class ToCategorical:
    def __init__(self, n_classes: int) -> None:
        self.n_classes = n_classes

    def __call__(self, sample: torch.Tensor) -> torch.Tensor:
        # (C, H, W) -> (H, W, C)
        mask = sample.permute(1, 2, 0)
        # get all categories other than 0 (does not assume 0 is present)
        categories = [c for c in torch.unique(mask).tolist() if c != 0]
        # build a tensor with `n_classes` channels; channel 0 stays all zeros
        one_hot_image = torch.zeros(self.n_classes, *mask.shape[:-1])
        for category in categories:
            # get spatial locations where the category is present
            rows, cols, _ = torch.where(mask == category)
            # at the same spatial locations, fill 1 in the `category` channel
            one_hot_image[category, rows, cols] = 1
        return one_hot_image
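For comparison, here is a minimal sketch of the same encoding done with `torch.nn.functional.one_hot`. The function name `to_categorical` and the tiny `demo` mask are made up for illustration, and the sketch assumes the red channel holds integer class indices below `n_classes`. Note one difference from the loop above: here background pixels (class 0) get a 1 in channel 0, whereas the loop leaves channel 0 empty.

```python
import torch
import torch.nn.functional as F


def to_categorical(mask: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical helper: mask is a (3, H, W) uint8 tensor, labels in the red channel."""
    labels = mask[0].to(torch.long)                       # (H, W) class indices
    one_hot = F.one_hot(labels, num_classes=n_classes)    # (H, W, n_classes)
    return one_hot.permute(2, 0, 1).to(torch.float32)     # (n_classes, H, W)


# Made-up 2x2 mask with classes 0, 5, 5 and 102 in the red channel
demo = torch.zeros(3, 2, 2, dtype=torch.uint8)
demo[0] = torch.tensor([[0, 5], [5, 102]], dtype=torch.uint8)

encoded = to_categorical(demo, 103)
print(encoded.shape)
```

Every pixel then has exactly one channel set to 1, including the background pixels.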

I then pass this one-hot image to the loss function as the output (prediction) and use the ground truth mask as the target.

import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

mask = T.PILToTensor()(Image.open("path_to_image").convert("RGB"))
categorical_mask = ToCategorical(103)(mask).unsqueeze(0)
mask = mask[0].unsqueeze(0)  # get only the red channel, add fake batch_dim

loss_fn = nn.CrossEntropyLoss()

target = mask
output = categorical_mask

print(output.shape, target.shape)
print(loss_fn(output, target.to(torch.long)))

I expected the loss to be zero, but to my surprise the output is as follows:

torch.Size([1, 103, 600, 800]) torch.Size([1, 600, 800])
tensor(4.2836)

I checked other samples in the dataset and obtained similar values for their masks as well. Am I doing something wrong? I expect the loss to be 0 when the output is the same as the target.
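One way to probe this is to compute the per-pixel loss on a tiny input by hand. nn.CrossEntropyLoss treats its input as unnormalized scores (logits) and applies log-softmax internally, so a one-hot tensor is read as "score 1 for one class, score 0 for the rest" rather than as probabilities. The sketch below uses a made-up two-pixel target (one background pixel, one pixel of class 5) and builds the prediction the way the transform above does, with channel 0 left empty. Under those assumptions the background pixel should cost log(103) ≈ 4.63 and the labelled pixel log(e + 102) − 1 ≈ 3.65, so a mean of 4.2836 over a mix of both kinds of pixel is plausible.

```python
import math

import torch
import torch.nn as nn

n_classes = 103
# Hypothetical tiny target: one background pixel (class 0), one pixel of class 5
target = torch.tensor([[[0, 5]]])                 # shape (N=1, H=1, W=2)

# Mimic the transform: channel 0 is never filled, other classes get a 1
output = torch.zeros(1, n_classes, 1, 2)
output[0, 5, 0, 1] = 1.0

# Per-pixel loss, without averaging
per_pixel = nn.CrossEntropyLoss(reduction="none")(output, target)

# Background pixel: all scores are 0, softmax is uniform, loss is log(103)
background_loss = math.log(n_classes)
# Labelled pixel: softmax gives e / (e + 102) to the correct class
labelled_loss = -math.log(math.e / (math.e + n_classes - 1))
print(per_pixel.flatten().tolist(), background_loss, labelled_loss)

# Large scores on the correct class (including channel 0 for background)
# behave like a confident prediction, and the loss approaches zero
confident = torch.zeros(1, n_classes, 1, 2)
confident.scatter_(1, target.unsqueeze(1), 100.0)
near_zero = nn.CrossEntropyLoss()(confident, target)
print(near_zero)
```

If this reading is right, a loss of exactly zero is only approached when the correct-class score is much larger than all the others, not when the input is a 0/1 one-hot tensor.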

PS. I also know that nn.CrossEntropyLoss is equivalent to applying log_softmax followed by nn.NLLLoss(), but I obtained the same value using nn.NLLLoss as well.
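The equivalence mentioned in the PS can be checked with a short sketch. The tensors here are random placeholders with the same (N, C, H, W) and (N, H, W) layout as above, not data from the dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 103, 4, 5)             # hypothetical (N, C, H, W) scores
labels = torch.randint(0, 103, (2, 4, 5))      # hypothetical (N, H, W) class indices

# Cross-entropy computed directly on the raw scores
ce = nn.CrossEntropyLoss()(logits, labels)
# Same quantity as log-softmax over the class dimension followed by NLL
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), labels)

print(ce.item(), nll.item())
```

Both paths apply a softmax over the class dimension, so switching between them would not be expected to change the value seen with a one-hot input.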

For Reference

Dataset used: UECFoodPixComplete



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
