binary_cross_entropy_with_logits: weight vs pos_weight, what are the differences?

According to PyTorch's documentation on binary_cross_entropy_with_logits, the two parameters are described as:

weight

weight (Tensor, optional) – a manual rescaling weight. If provided, it's repeated to match the input tensor shape.

pos_weight

pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.

What are their differences? The explanation is quite vague. If I understand correctly, weight is an individual weight for each pixel (class), whereas pos_weight is the weight for everything that's not background (i.e., not a negative/zero pixel)?

What if I set both parameters? For example:

import torch
from torch.nn.functional import binary_cross_entropy_with_logits

preds = torch.randn(4, 100, 50, 50)

# binary target: a 10x10 square of positive pixels, zeros (background) elsewhere
target = torch.zeros((4, 100, 50, 50))
target[:, :, 10:20, 10:20] = 1

# 100 on positive pixels, 1 everywhere else
pos_weight = target * 100
pos_weight[pos_weight < 100] = 1
weight = target * 100
weight[weight < 100] = 1

loss1 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight, weight=weight)
loss2 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight)
loss3 = binary_cross_entropy_with_logits(preds, target, weight=weight)

Of loss1, loss2, and loss3, which one is the correct usage?

On the same subject, I was reading a paper that said:

To deal with the unbalanced negative and positive data, we dilate each keypoint by 10 pixels and use weighted cross-entropy loss. The weight for each keypoint is set to 100 while for non-keypoint pixels it is set to 1.

Which one is the correct usage according to the paper?
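
For reference, here is how I would construct the weight map the paper describes. This is only my sketch: the dilation via max pooling is my own interpretation of the paper, and whether the resulting map belongs in weight or pos_weight is exactly what I'm unsure about:

import torch
import torch.nn.functional as F

keypoints = torch.zeros(4, 1, 50, 50)  # binary keypoint map, one keypoint per image
keypoints[:, :, 25, 25] = 1

# dilate each keypoint by 10 pixels: a 21x21 max-pool marks every pixel
# within 10 px of a keypoint (my interpretation of the paper's dilation)
dilated = F.max_pool2d(keypoints, kernel_size=21, stride=1, padding=10)

# weight 100 on (dilated) keypoint pixels, 1 on non-keypoint pixels
weight_map = torch.ones_like(dilated)
weight_map[dilated > 0] = 100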

Thanks in advance for any explanation!



Solution 1 [1]

The pos_weight parameter allows you to rescale the positive examples, thus controlling the trade-off between recall and precision: per the docs, a pos_weight greater than 1 increases recall, while a value less than 1 increases precision. On the other hand, weight allows you to weight the different elements of a given batch.
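
The explicit per-element expression from the PyTorch docs makes the difference clear. With w_n taken from weight and p from pos_weight, the loss for element n is:

    l_n = -w_n * [ p * y_n * log(sigmoid(x_n)) + (1 - y_n) * log(1 - sigmoid(x_n)) ]

So weight multiplies the whole per-element loss, while pos_weight scales only the positive (y_n = 1) term. Passing both at once, as in your loss1, is therefore perfectly valid: the two factors simply compose.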

Here is a minimal example:

>>> import torch
>>> target = torch.ones([10, 64], dtype=torch.float32)  # all-positive targets
>>> output = torch.full([10, 64], 1.5)  # raw logits

>>> criterion = torch.nn.BCEWithLogitsLoss() # w/o weight
>>> criterion(output, target)
tensor(0.2014) # all batch elements weighted equally

>>> weight = torch.rand(10, 1)  # one random weight per batch sample (unseeded, so values vary)
>>> criterion = torch.nn.BCEWithLogitsLoss(weight=weight) # w/ weight
>>> criterion(output, target)
tensor(0.0908) # per-element weighting (exact value depends on the random weight)

Which is identical to doing:

>>> criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
>>> torch.mean(criterion(output, target)*weight)
tensor(0.0908)
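
For completeness, here is a sketch of pos_weight on the same tensors. Since target is all ones, every element is a positive example, so the loss is scaled by exactly pos_weight:

>>> pos_weight = torch.full([64], 2.)  # one weight per class, broadcast over the batch
>>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
>>> criterion(output, target)
tensor(0.4028) # 2 * 0.2014: only the positive (target == 1) term is scaled

And both can be combined, with weight multiplying the whole per-element loss and pos_weight scaling only its positive term:

>>> criterion = torch.nn.BCEWithLogitsLoss(weight=weight, pos_weight=pos_weight)
>>> criterion(output, target)
tensor(0.1816) # 2 * 0.0908, using the same random weight as above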

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

[1] Solution 1 by Ivan, Stack Overflow.