ValueError: Target size (torch.Size([32, 16])) must be the same as input size (torch.Size([32, 6]))

I'm trying to use a BERT language model to train a multilabel classifier. The data has a text_input column and 16 category columns; a category column contains 1 if the text_input belongs to that category.

import torch
from transformers import BertModel

class BERTClass(torch.nn.Module):
    def __init__(self):
        super(BERTClass, self).__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(768, 6)
    
    def forward(self, input_ids, attn_mask, token_type_ids):
        output = self.bert_model(
            input_ids, 
            attention_mask=attn_mask, 
            token_type_ids=token_type_ids
        )
        output_dropout = self.dropout(output.pooler_output)
        output = self.linear(output_dropout)
        return output

model = BERTClass()

ids = data['input_ids'].to(device, dtype = torch.long)
mask = data['attention_mask'].to(device, dtype = torch.long)
token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            
targets = data['targets'].to(device, dtype = torch.float)
            
outputs = model(ids, mask, token_type_ids)

The data was encoded with:

tokenizer.encode_plus(
            title,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

I checked the shapes of targets and outputs:

<class 'torch.Tensor'>
torch.Size([32, 16])                targets.shape
<class 'torch.Tensor'>
torch.Size([32, 6])                 outputs.shape
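
The error message at the top can be reproduced with plain tensors, which helps confirm that the mismatch is in the shapes and not in BERT itself. This is only a sketch: the question does not show the loss function, so torch.nn.BCEWithLogitsLoss is an assumption here (the wording of the error suggests it), and the random tensors are stand-ins for the real model outputs and labels.

```python
import torch

# Stand-ins for the real tensors: batch of 32, 6 logits vs. 16 label columns.
logits = torch.randn(32, 6)                       # width produced by Linear(768, 6)
targets = torch.randint(0, 2, (32, 16)).float()   # 16 multi-hot category columns

# Assumed loss function; the question does not say which one is used.
loss_fn = torch.nn.BCEWithLogitsLoss()

try:
    loss_fn(logits, targets)
    error_message = None
except ValueError as exc:
    # The loss refuses inputs and targets whose shapes differ.
    error_message = str(exc)

print(error_message)
```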

I think the problem might be caused by self.linear = torch.nn.Linear(768, 6) in BERTClass, but I'm not sure how to fix it, or whether I even need this linear layer for classification. Any ideas?
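
One likely fix, sketched under assumptions: keep the linear layer but give it one output per category, so the logits have the same width as the 16 target columns. To stay runnable without downloading BERT, the sketch below uses a hypothetical ClassifierHead that mirrors only the dropout and linear part of BERTClass, and feeds it a random tensor in place of output.pooler_output. BCEWithLogitsLoss is again an assumption, since the question does not show the loss.

```python
import torch

NUM_LABELS = 16     # one logit per category column
HIDDEN_SIZE = 768   # pooled output size of bert-base-uncased


class ClassifierHead(torch.nn.Module):
    """Hypothetical stand-in for the dropout + linear part of BERTClass."""

    def __init__(self, hidden_size=HIDDEN_SIZE, num_labels=NUM_LABELS):
        super().__init__()
        self.dropout = torch.nn.Dropout(0.3)
        # The output width must match the number of label columns.
        self.linear = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output):
        return self.linear(self.dropout(pooled_output))


head = ClassifierHead()
pooled = torch.randn(32, HIDDEN_SIZE)                     # fake pooler_output
targets = torch.randint(0, 2, (32, NUM_LABELS)).float()   # fake multi-hot labels

outputs = head(pooled)                                    # raw logits, shape [32, 16]
loss = torch.nn.BCEWithLogitsLoss()(outputs, targets)     # shapes now agree

print(outputs.shape, loss.item())
```

In the original class, the equivalent change would be replacing torch.nn.Linear(768, 6) with torch.nn.Linear(768, 16). The linear layer itself is still needed, because BERT's pooled output has 768 features and something has to map them to one score per category.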



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
