`logits` and `labels` must have the same shape, received ((None, 512, 768) vs (None, 1)) when using transformers

I am trying to fine-tune a BERT model for sentiment analysis, but I always get the error below about `logits` and `labels` when I call `model.fit`. I load a pretrained model and build the dataset, but fitting fails every time.

The inputs are `X`, a list of strings containing tweets, and `y`, a list of labels. The labels were originally the categories negative and positive and have been converted to numbers (0 = negative, 1 = positive).

```python
import tensorflow as tf
from sklearn.preprocessing import MultiLabelBinarizer
from tensorflow.keras.losses import BinaryCrossentropy
from transformers import BertTokenizer, TFBertModel

# LOAD MODEL
hugging_face_model = 'distilbert-base-uncased-finetuned-sst-2-english'
batches = 32
epochs = 1

tokenizer = BertTokenizer.from_pretrained(hugging_face_model)
model = TFBertModel.from_pretrained(hugging_face_model, num_labels=2)

# PREPARE THE DATASET
# Create a list of strings (tweets)
lst = list(X_train_lower['lower_text'].values)
encoded_input = tokenizer(lst, truncation=True, padding=True, return_tensors='tf')

# Map the sentiment categories to 0 (negative) and 1 (positive)
y_train['sentimentNumber'] = y_train['sentiment'].replace({'negative': 0, 'positive': 1})
label_list = list(y_train['sentimentNumber'].values)

# CREATE DATASET
train_dataset = tf.data.Dataset.from_tensor_slices((dict(encoded_input), label_list))

# COMPILE AND FIT THE MODEL
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5), loss=BinaryCrossentropy(from_logits=True), metrics=["accuracy"])
model.fit(train_dataset.shuffle(len(df)).batch(batches), epochs=epochs, batch_size=batches)
```




ValueError                                Traceback (most recent call last)
<ipython-input-158-e5b63f982311> in <module>()
----> 1 model.fit(train_dataset.shuffle(len(df)).batch(batches),epochs=epochs,batch_size=batches)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145           except Exception as e:  # pylint:disable=broad-except
   1146             if hasattr(e, "ag_error_metadata"):
-> 1147               raise e.ag_error_metadata.to_exception(e)
   1148             else:
   1149               raise

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/transformers/modeling_tf_utils.py", line 1000, in train_step
        loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 141, in __call__
        losses = call_fn(y_true, y_pred)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 245, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 1932, in binary_crossentropy
        backend.binary_crossentropy(y_true, y_pred, from_logits=from_logits),
    File "/usr/local/lib/python3.7/dist-packages/keras/backend.py", line 5247, in binary_crossentropy
        return tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)

    ValueError: `logits` and `labels` must have the same shape, received ((None, 512, 768) vs (None, 1)).
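
Judging from the shapes in the error, one likely cause is that `TFBertModel` is the bare encoder without a classification head, so the loss receives hidden states of shape `(batch, 512, 768)` instead of one score per class. The checkpoint name also points to a DistilBERT model rather than a BERT one. Below is a minimal sketch of a possible fix, assuming the `Auto*` classes from `transformers` are available in the installed version. The names `encode_labels` and `fine_tune` are hypothetical helpers introduced only for illustration, and swapping in `SparseCategoricalCrossentropy` is a change from the loss used in the question.

```python
from typing import Dict, List

# Label mapping taken from the question: 0 = negative, 1 = positive.
LABEL_TO_ID: Dict[str, int] = {"negative": 0, "positive": 1}


def encode_labels(sentiments: List[str]) -> List[int]:
    """Map sentiment strings to integer class ids, one id per example."""
    return [LABEL_TO_ID[s] for s in sentiments]


def fine_tune(texts: List[str], sentiments: List[str],
              checkpoint: str = "distilbert-base-uncased-finetuned-sst-2-english",
              batch_size: int = 32, epochs: int = 1):
    """Hypothetical helper: fine-tune a sequence-classification model on tweets."""
    # Imported here so encode_labels can be used without TensorFlow installed.
    import tensorflow as tf
    from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

    # The Auto* classes select the architecture that matches the checkpoint
    # (DistilBERT here) and load it with a classification head.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        checkpoint, num_labels=2
    )

    encoded = tokenizer(texts, truncation=True, padding=True, return_tensors="tf")
    labels = encode_labels(sentiments)
    dataset = tf.data.Dataset.from_tensor_slices((dict(encoded), labels))

    # With a classification head the logits have shape (batch, 2); the labels
    # are integer ids of shape (batch,), which is the pairing that
    # SparseCategoricalCrossentropy expects.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # The dataset is already batched, so batch_size is not passed to fit().
    model.fit(dataset.shuffle(len(texts)).batch(batch_size), epochs=epochs)
    return model
```

With the variables from the question, a call might look like `fine_tune(lst, list(y_train['sentiment'].values))`. Keeping `BinaryCrossentropy` would instead need a head with a single output (`num_labels=1`) so that the logits match the `(None, 1)` labels.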


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
