PyTorch BCEWithLogitsLoss: calculating pos_weight
I have the neural network below for binary prediction. My classes are heavily imbalanced: class 1 occurs only 2% of the time. Only the last few layers are shown:
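```python
# Last few layers, defined in the model's __init__ (num_filters and fc2_neurons are set earlier)
self.batch_norm2 = nn.BatchNorm1d(num_filters)
self.fc2 = nn.Linear(np.sum(num_filters), fc2_neurons)
self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)
self.fc3 = nn.Linear(fc2_neurons, 1)  # single output unit: one value (logit) per sample
```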
My loss is defined as below. Is this the correct way to calculate the pos_weight parameter? I looked at the official documentation, and it shows that pos_weight needs one value per class for multiclass classification. I am not sure whether the binary case is different. I tried passing 2 values and got an error.
My question: for a binary problem, is pos_weight a single value, unlike multiclass classification, where it needs to be a list/array with length equal to the number of classes?
```python
BCE_With_LogitsLoss = nn.BCEWithLogitsLoss(pos_weight=class_wts[0] / class_wts[1])
```
My y variable is a single variable with value 0 or 1 representing the actual class, and the neural network outputs a single value.
Update 1:
Based on Shai's answer, I have the following questions:
- BCEWithLogitsLoss: if it is a multiclass problem, how do I use the pos_weight parameter?
- Is there any example of using focal loss in PyTorch? I found some links, but most of them were old, dating back two, three, or more years.
- For training, I am oversampling class 1. Is focal loss still appropriate?
Solution 1:[1]
The documentation of pos_weight is indeed a bit unclear. For BCEWithLogitsLoss, pos_weight should be a torch.Tensor of size 1:
```python
BCE_With_LogitsLoss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_wts[0] / class_wts[1]]))
```
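As a minimal sketch, the size-1 pos_weight can be computed directly from the training labels. The PyTorch documentation's example sets pos_weight to (number of negatives) / (number of positives), e.g. 300/100 = 3. If class_wts holds the class counts [n_0, n_1], then class_wts[0]/class_wts[1] is exactly this ratio; if it holds inverse-frequency weights instead, the ratio would be inverted. The helper name pos_weight_from_labels is hypothetical, and the labels are assumed to be a 0/1 tensor of shape (N, C). With C = 1 (the single-logit model from the question) it returns one value; with C > 1 (multi-label, one independent sigmoid output per class) it returns one value per class, which is the case the documentation describes.

```python
import torch
import torch.nn as nn

def pos_weight_from_labels(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-column ratio of negatives to positives.

    y is a 0/1 label tensor of shape (N, C); C == 1 for a single-logit binary model.
    Returns a float tensor of shape (C,), usable as pos_weight.
    """
    y = y.float()
    num_pos = y.sum(dim=0)
    num_neg = y.shape[0] - num_pos
    # clamp avoids division by zero for a column with no positive examples
    return num_neg / num_pos.clamp(min=1)

# Binary case: 98 negatives and 2 positives (the 2% ratio from the question)
y = torch.cat([torch.zeros(98, 1), torch.ones(2, 1)])
pos_weight = pos_weight_from_labels(y)  # tensor([49.])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.randn(100, 1)  # stand-in for the output of fc3
loss = criterion(logits, y)   # targets must be float and have the same shape as the logits
```

For mutually exclusive classes (true multiclass rather than multi-label), nn.CrossEntropyLoss with its weight argument is the usual choice instead of BCEWithLogitsLoss.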
However, in your case, where the positive class occurs only 2% of the time, I think setting pos_weight will not be enough.
Please consider using Focal loss:
Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár Focal Loss for Dense Object Detection (ICCV 2017).
Apart from describing focal loss, this paper gives a very good explanation of why cross-entropy (CE) loss performs so poorly under class imbalance. I strongly recommend reading it.
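Since the update asks for a focal-loss example, here is a minimal sketch of the binary (sigmoid) focal loss from the Lin et al. paper, FL(p_t) = -α_t (1 - p_t)^γ log(p_t), built on top of binary_cross_entropy_with_logits so it takes the raw fc3 logits. The function name is hypothetical, not a PyTorch API; the defaults α = 0.25 and γ = 2 are the values the paper reports working well for detection and would likely need tuning for this data. torchvision also provides torchvision.ops.sigmoid_focal_loss with a similar formulation.

```python
import torch
import torch.nn.functional as F

def binary_focal_loss_with_logits(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Hypothetical helper: sigmoid focal loss, -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    targets = targets.float()
    # Per-element -log(p_t), computed stably from the logits
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balancing weight
    loss = alpha_t * (1 - p_t) ** gamma * ce                 # (1 - p_t)**gamma down-weights easy examples
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

# Usage with a single-logit model like the one in the question
logits = torch.randn(8, 1)
y = torch.randint(0, 2, (8, 1)).float()
loss = binary_focal_loss_with_logits(logits, y)
```

With γ = 0 and α = 0.5 this reduces to half the plain BCE, which is a quick sanity check. Computing the log term through binary_cross_entropy_with_logits rather than torch.log(torch.sigmoid(x)) keeps it numerically stable for large-magnitude logits.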
Other alternatives are listed here.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Shai |
