PyTorch & BERT - Predict Multiple Binary Features

I would like to manually add an output layer to BERT in order to predict multiple features which are binary.

For example, these outputs would answer the questions:

  • Is the text positive? 1 if yes, 0 otherwise.
  • Is this text about sports? 1 if yes, 0 if not.
  • Is this text about business? 1 if yes, 0 if not.

My first idea was to do it as a regression task, adding an output layer with 3 neurons: one for each question.

The pseudo-code:

import torch.nn as nn
from transformers import CamembertModel

class MultiLabelBert(nn.Module):
    def __init__(self, dim_in=768):  # 768 = hidden size of camembert-base
        super().__init__()
        self.bert = CamembertModel.from_pretrained('camembert-base')
        self.regressor = nn.Sequential(nn.Linear(dim_in, 3))  # one output per question

    def forward(self, input_ids):
        outputs = self.bert(input_ids)      # BERT layers
        pooled = outputs.pooler_output      # (batch_size, dim_in)
        return self.regressor(pooled)       # (batch_size, 3)

But I'll get values above 1 and below 0. So, what kind of layer could I add in order to get probabilities (values between 0 and 1) for this problem?
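For illustration (a minimal sketch, not part of the original post): a sigmoid squashes each of the three unbounded linear outputs independently into the interval (0, 1), so each one can be read as the probability of a "yes":

```python
import torch

# Three raw outputs of a linear layer are unbounded logits
logits = torch.tensor([-2.5, 0.0, 3.7])

# torch.sigmoid maps each value independently into (0, 1),
# so each output can be interpreted as a per-question probability
probs = torch.sigmoid(logits)   # sigmoid(0.0) == 0.5
```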

I hope my question was clear, thank you for your help.



Solution 1:[1]

A single nn.Linear head implements one classification task: nn.Linear(dim_in, 3) represents a three-class classification, where the index of the largest of the three outputs is the predicted class. Based on your description, I recommend either combining the labels into a single six-class classification, or performing a separate binary classification for each of the three tasks; the latter is more cumbersome.
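One way to sketch the "separate binary classification per task" option (assuming the model emits one logit per question; the tensors below are made up) is nn.BCEWithLogitsLoss, which treats each output as an independent yes/no problem:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical logits for a batch of 4 texts, one logit per binary question
logits = torch.randn(4, 3, requires_grad=True)
targets = torch.tensor([[1., 0., 0.],
                        [0., 1., 0.],
                        [1., 0., 1.],
                        [0., 0., 0.]])   # made-up 0/1 labels per question

# BCEWithLogitsLoss applies a sigmoid to each logit and a binary
# cross-entropy per output: three independent yes/no classifications
loss = nn.BCEWithLogitsLoss()(logits, targets)
loss.backward()

# At inference time, threshold the sigmoid probabilities
preds = (torch.sigmoid(logits) > 0.5).float()
```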

If you go the six-class route, you need to re-label the samples with the combined class and then train with the cross-entropy loss (note that PyTorch's nn.CrossEntropyLoss takes integer class indices directly, so explicit one-hot encoding is not required).
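A minimal sketch of that six-class setup (the class indices below are illustrative): PyTorch's nn.CrossEntropyLoss expects integer class indices and applies log-softmax internally, so one-hot labels only describe the relabeling, not the tensor passed to the loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 6                       # one class per combined label, as suggested
logits = torch.randn(4, num_classes)  # e.g. the output of nn.Linear(dim_in, 6)

# nn.CrossEntropyLoss takes integer class indices (it applies
# log-softmax + negative log-likelihood internally), not one-hot vectors
targets = torch.tensor([0, 3, 5, 2])  # illustrative relabeled classes
loss = nn.CrossEntropyLoss()(logits, targets)

# The predicted class is the index of the largest logit
preds = logits.argmax(dim=1)
```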

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
