ValueError: `class_weight` is only supported for Models with a single output

I'm getting the error below when passing class weights to model.fit in TensorFlow 2.7.0:

ValueError                                Traceback (most recent call last)
<ipython-input-21-27678ea5cbb9> in <module>
      1 # Fit the model
----> 2 history = model.fit(
      3     x={'input_ids': x['input_ids'], 'attention_mask': x['attention_mask']},
      4     y={'category': y_category },
      5     validation_split=0.2,

~\Anaconda3\lib\site-packages\keras\utils\traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

~\Anaconda3\lib\site-packages\keras\engine\data_adapter.py in _class_weights_map_fn(*data)
   1434 
   1435     if tf.nest.is_nested(y):
-> 1436       raise ValueError(
   1437           "`class_weight` is only supported for Models with a single output.")
   1438 

ValueError: `class_weight` is only supported for Models with a single output.
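The last frame shows the check that fails. Keras's _class_weights_map_fn raises whenever tf.nest.is_nested(y) is true. A dict of targets counts as nested even when it holds only one entry, while a plain array does not. Below is a minimal sketch of that distinction. It assumes TensorFlow 2.x with NumPy, and the arrays are made-up placeholders, not the question's data:

```python
import numpy as np
import tensorflow as tf

# Placeholder one-hot targets: 4 samples, 32 classes (matching the model's output width)
y_array = np.eye(32, dtype="float32")[[0, 1, 2, 3]]

# The same targets wrapped in a dict keyed by the output layer name, as in the fit call
y_dict = {"category": y_array}

dict_is_nested = tf.nest.is_nested(y_dict)    # a dict is a nested structure
array_is_nested = tf.nest.is_nested(y_array)  # a NumPy array is treated as a leaf

print(dict_is_nested)   # True  -> class_weight raises ValueError
print(array_is_nested)  # False -> class_weight is accepted
```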

My fit function is:

# Fit the model (y is keyed by the output layer name, 'category')
history = model.fit(
    x={'input_ids': x['input_ids'], 'attention_mask': x['attention_mask']},
    y={'category': y_category},
    validation_split=0.2,
    class_weight=d_class_weights,
    batch_size=6,
    epochs=10)

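and I computed d_class_weights as follows:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Convert one-hot targets to integer labels without overwriting y_category,
# which is still needed as the training target
y_labels = np.argmax(y_category, axis=1)
# classes and y are keyword-only arguments in recent scikit-learn versions
class_weights = compute_class_weight('balanced', classes=np.unique(y_labels), y=y_labels)
# Assumes every class index 0..31 occurs in y_labels, so positions line up with labels
d_class_weights = dict(enumerate(class_weights))
d_class_weights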
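For reference, scikit-learn documents the 'balanced' mode as n_samples / (n_classes * np.bincount(y)). The following dependency-free sketch shows the same arithmetic so the resulting weights are easy to sanity-check. The helper name balanced_class_weights is hypothetical and is not part of any library:

```python
from collections import Counter


def balanced_class_weights(labels):
    """Hypothetical helper: per-class weights via the 'balanced' heuristic,
    n_samples / (n_classes * count_of_class), keyed by class label."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


# Toy labels: class 0 is three times as frequent as class 1
weights = balanced_class_weights([0, 0, 0, 1])
print(weights)  # class 0 -> 4 / (2 * 3), class 1 -> 4 / (2 * 1) = 2.0
```

Rare classes get weights above 1 and frequent classes get weights below 1. That is what offsets the imbalance during training.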

Without class_weight the model trains fine, but the dataset is highly imbalanced, so I need the weights.
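One commonly used alternative is to expand the class weights into one weight per training sample and pass that array as sample_weight. The check shown in the traceback lives in _class_weights_map_fn, so it does not apply to sample_weight. This is offered as an untested sketch, not a confirmed fix. The helper to_sample_weights and the example values are hypothetical:

```python
import numpy as np


def to_sample_weights(labels, class_weights):
    """Hypothetical helper: look up each integer label's class weight,
    returning one float32 weight per sample."""
    labels = np.asarray(labels)
    return np.array([class_weights[int(c)] for c in labels], dtype="float32")


# Placeholder integer labels and a class-weight dict shaped like d_class_weights
labels = np.array([0, 0, 0, 1])
weights = {0: 2 / 3, 1: 2.0}
sample_weight = to_sample_weights(labels, weights)
print(sample_weight)

# The array would then replace class_weight in the call, e.g.
# model.fit(..., sample_weight=sample_weight)
```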

Model:

Model: "BERT_MultiLabel_MultiClass"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 attention_mask (InputLayer)    [(None, 300)]        0           []                               
                                                                                                  
 input_ids (InputLayer)         [(None, 300)]        0           []                               
                                                                                                  
 bert (TFBertMainLayer)         TFBaseModelOutputWi  109482240   ['attention_mask[0][0]',         
                                thPooling(last_hidd               'input_ids[0][0]']              
                                en_state=(None, 300                                               
                                , 768),                                                           
                                 pooler_output=(Non                                               
                                e, 768),                                                          
                                 hidden_states=None                                               
                                , attentions=None)                                                
                                                                                                  
 pooled_output (Dropout)        (None, 768)          0           ['bert[0][1]']                   
                                                                                                  
 category (Dense)               (None, 32)           24608       ['pooled_output[0][0]']          
                                                                                                  
==================================================================================================
Total params: 109,506,848
Trainable params: 109,506,848
Non-trainable params: 0
__________________________________________________________________________________________________
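The summary lists a single output layer named category, so the targets can be passed as a plain array instead of a dict. Keras then no longer sees a nested y. Below is a minimal sketch of that workaround, assuming TF 2.x Keras. It is not a verified fix for the BERT model. A small Embedding-based stand-in replaces TFBertMainLayer, the inputs reuse the names input_ids and attention_mask, and the data and class weights are random or placeholder values:

```python
import numpy as np
import tensorflow as tf

layers = tf.keras.layers
SEQ_LEN, NUM_CLASSES, N = 300, 32, 64

# Two named inputs, mirroring the input names in the model summary
input_ids = tf.keras.Input(shape=(SEQ_LEN,), dtype="int32", name="input_ids")
attention_mask = tf.keras.Input(shape=(SEQ_LEN,), dtype="float32", name="attention_mask")

# Hypothetical stand-in for TFBertMainLayer: embedding + average pooling
pooled_ids = layers.GlobalAveragePooling1D()(layers.Embedding(1000, 16)(input_ids))
mask_features = layers.Dense(8, activation="relu")(attention_mask)
hidden = layers.Concatenate()([pooled_ids, mask_features])
hidden = layers.Dropout(0.1, name="pooled_output")(hidden)

# Single output layer, as in the summary
category = layers.Dense(NUM_CLASSES, activation="softmax", name="category")(hidden)

model = tf.keras.Model(
    inputs={"input_ids": input_ids, "attention_mask": attention_mask},
    outputs=category,
)
model.compile(optimizer="adam", loss="categorical_crossentropy")

# Random placeholder data; every class index appears at least once
rng = np.random.default_rng(0)
x = {
    "input_ids": rng.integers(0, 1000, size=(N, SEQ_LEN)).astype("int32"),
    "attention_mask": np.ones((N, SEQ_LEN), dtype="float32"),
}
y_labels = np.arange(N) % NUM_CLASSES
y_category = np.eye(NUM_CLASSES, dtype="float32")[y_labels]

# Placeholder class weights with keys 0..NUM_CLASSES-1
d_class_weights = {c: 1.0 + (c % 2) for c in range(NUM_CLASSES)}

history = model.fit(
    x=x,
    y=y_category,  # plain array instead of {'category': y_category}
    validation_split=0.2,
    class_weight=d_class_weights,
    batch_size=6,
    epochs=1,
    verbose=0,
)
print(sorted(history.history.keys()))
```

Only the y argument differs from the original call. The dict form of x is unaffected, because the check in the traceback inspects y alone.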


Solution 1:[1]

I managed to work around this error by downgrading TensorFlow from 2.8.0 to 2.1.0, which meant I also had to downgrade Python to 3.7 and the transformers package to 3.1.0.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 nikviz