binary_cross_entropy_with_logits: weight vs pos_weight, what are the differences?
According to PyTorch's documentation on binary_cross_entropy_with_logits, the two parameters are described as:
weight
weight (Tensor, optional) – a manual rescaling weight; if provided, it is repeated to match the input tensor shape
pos_weight
pos_weight (Tensor, optional) – a weight of positive examples. Must be a vector with length equal to the number of classes.
What are their differences? The explanation is quite vague. If I understand correctly, weight is an individual weight for each pixel (class), whereas pos_weight is the weight for everything that is not background (i.e., not a negative/zero pixel)?
What if I set both parameters? For example:
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

preds = torch.randn(4, 100, 50, 50)
target = torch.zeros((4, 100, 50, 50))
target[:, :, 10:20, 10:20] = 1  # a 10x10 square of positive pixels in every channel

# both tensors: 100 on positive pixels, 1 everywhere else
pos_weight = target * 100
pos_weight[pos_weight < 100] = 1
weight = target * 100
weight[weight < 100] = 1

loss1 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight, weight=weight)
loss2 = binary_cross_entropy_with_logits(preds, target, pos_weight=pos_weight)
loss3 = binary_cross_entropy_with_logits(preds, target, weight=weight)
Of loss1, loss2, and loss3, which one is the correct usage?
On the same subject, I was reading a paper that said:
To deal with the unbalanced negative and positive data, we dilate each keypoint by 10 pixels and use weighted cross-entropy loss. The weight for each keypoint is set to 100 while for non-keypoint pixels it is set to 1.
Which usage is correct according to the paper?
Thanks in advance for any explanation!
Solution 1:[1]
The pos_weight parameter allows you to weight the positive examples, thus controlling the tradeoff between recall and precision. A detailed explanation, along with the explicit math expression, can be found on this thread.
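In short, the documented per-element expression (leaving weight aside for a moment) is loss = -[pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], i.e. pos_weight scales only the positive (y = 1) term. A minimal sketch checking this numerically, with shapes and values chosen purely for illustration:

```python
import torch

target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)
pos_weight = torch.full([64], 3.0)  # one weight per class (last dimension)

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
loss = criterion(output, target)

# manual computation: pos_weight only scales the positive term of each element's loss
p = torch.sigmoid(output)
manual = -(pos_weight * target * torch.log(p)
           + (1 - target) * torch.log(1 - p)).mean()
print(torch.allclose(loss, manual))  # True
```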
On the other hand, weight allows you to weight the individual elements of a given batch.
Here is a minimal example:
>>> import torch
>>> target = torch.ones([10, 64], dtype=torch.float32)
>>> output = torch.full([10, 64], 1.5)
>>> criterion = torch.nn.BCEWithLogitsLoss() # w/o weight
>>> criterion(output, target)
tensor(0.2014) # all batch elements weighted equally
>>> weight = torch.rand(10,1)
>>> criterion = torch.nn.BCEWithLogitsLoss(weight=weight) # w/ weight
>>> criterion(output, target)
tensor(0.0908) # per element weighting
Which is identical to doing:
>>> criterion = torch.nn.BCEWithLogitsLoss(reduction='none')
>>> torch.mean(criterion(output, target)*weight)
tensor(0.0908)
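As for setting both parameters at once (loss1 in the question): per the formula in the PyTorch docs, weight multiplies the whole per-element loss while pos_weight scales only its positive term, so the two compose rather than conflict. A minimal sketch of the combined expression, with values chosen purely for illustration:

```python
import torch
from torch.nn.functional import binary_cross_entropy_with_logits

target = torch.ones([10, 64], dtype=torch.float32)
output = torch.full([10, 64], 1.5)
weight = torch.rand(10, 1)          # rescales each element's whole loss
pos_weight = torch.full([64], 3.0)  # extra factor on the positive term only

loss = binary_cross_entropy_with_logits(output, target,
                                        weight=weight, pos_weight=pos_weight)

# manual: -weight * [pos_weight * y * log(p) + (1 - y) * log(1 - p)]
p = torch.sigmoid(output)
manual = (-weight * (pos_weight * target * torch.log(p)
                     + (1 - target) * torch.log(1 - p))).mean()
print(torch.allclose(loss, manual))  # True
```

Note that in the question's construction both tensors are 100 on the positive pixels and 1 elsewhere; since a positive pixel's loss consists only of the positive term, loss2 and loss3 should come out essentially the same, whereas loss1 applies the factor of 100 twice to those pixels. For the paper's scheme (100 on keypoint pixels, 1 on the rest), passing just one of the two should therefore be enough.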
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Ivan |
