PyTorch & BERT - Predict Multiple Binary Features
I would like to manually add an output layer to BERT in order to predict multiple features which are binary.
For example, these outputs would answer the questions:
- Is the text positive? 1 if yes, 0 otherwise.
- Is this text about sports? 1 if yes, 0 if not.
- Is this text about business? 1 if yes, 0 if not.
My first idea was to treat it as a regression task, adding an output layer with 3 neurons: one for each question.
The pseudo-code:

```python
import torch.nn as nn
from transformers import CamembertModel

class MultiLabelBert(nn.Module):
    def __init__(self, dim_in=768):
        super().__init__()
        self.bert = CamembertModel.from_pretrained('camembert-base')
        self.regressor = nn.Sequential(nn.Linear(dim_in, 3))

    def forward(self, input_ids):
        outputs = self.bert(input_ids)    # "BERT layers"
        pooled = outputs.pooler_output    # [batch, dim_in]
        return self.regressor(pooled)     # linear head with 3 outputs
```
But I get values above 1 and below 0. So, which kind of layer could I add in order to get probabilities (values between 0 and 1) for this problem?
I hope my question was clear, thank you for your help.
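For reference, a minimal sketch of squashing each of the 3 linear outputs independently into (0, 1) with a sigmoid (the hidden size of 768 and the random stand-in for the pooled BERT output are assumptions for illustration):

```python
import torch
import torch.nn as nn

# Stand-in for the pooled BERT output: batch of 2, hidden size 768 (assumption).
pooled = torch.randn(2, 768)

head = nn.Linear(768, 3)       # one logit per binary question
logits = head(pooled)          # unbounded real values
probs = torch.sigmoid(logits)  # each squashed independently into (0, 1)

print(probs)  # every entry lies strictly between 0 and 1
```

During training one would typically keep the raw logits and use `nn.BCEWithLogitsLoss`, which applies the sigmoid internally in a numerically stable way.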
Solution 1:[1]
A single `nn.Linear(dim_in, 3)` head implements one three-class classification task: the index of the largest of the three outputs is the predicted class. Based on your description, it is recommended that you either perform one six-class classification directly, or train a separate binary classifier for each of the three tasks, though the latter is more cumbersome.
If you go with the six-class task, you need to re-label the samples, with each one represented by a one-hot encoding, and then train with the cross-entropy loss function.
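A sketch of the single multi-class head this solution describes. The class count of 6 and the random stand-in for the BERT pooled output are assumptions; note that PyTorch's `nn.CrossEntropyLoss` expects integer class indices as targets rather than one-hot vectors (it applies the softmax internally):

```python
import torch
import torch.nn as nn

num_classes = 6                      # assumed number of combined labels
head = nn.Linear(768, num_classes)   # replaces the 3-neuron regressor

pooled = torch.randn(4, 768)         # stand-in for the BERT pooled output
logits = head(pooled)                # [batch, num_classes]

# Integer class index per sample, not a one-hot vector.
targets = torch.tensor([0, 2, 5, 1])
loss = nn.CrossEntropyLoss()(logits, targets)

predicted = logits.argmax(dim=1)     # index of the largest logit = class
```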
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Stack Overflow |
