Customize model.fit() in TensorFlow Keras - TypeError: compile() got an unexpected keyword argument 'loss'

I am trying to implement my own weight-training algorithm for knowledge distillation by customizing model.fit() in Keras. I have read this link.

It works perfectly fine when the input data is images, but it won't work when the input data is text embeddings. (P.S. I have stripped out the network details to show how it works with images but not with text embeddings.) The models that I have created for images are:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)

The custom Distiller() class overrides the Model methods train_step, test_step, and compile().

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super(Distiller, self).__init__()
        self.teacher = teacher
        self.student = student

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        """ Configure the distiller.

        Args:
            optimizer: Keras optimizer for the student weights
            metrics: Keras metrics for evaluation
            student_loss_fn: Loss function of difference between student
                predictions and ground-truth
            distillation_loss_fn: Loss function of difference between soft
                student predictions and soft teacher predictions
            alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
            temperature: Temperature for softening probability distributions.
                Larger temperature gives softer distributions.
        """
        super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`.
        self.compiled_metrics.update_state(y, student_predictions)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        # Unpack the data
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Calculate the loss
        student_loss = self.student_loss_fn(y, y_prediction)

        # Update the metrics.
        self.compiled_metrics.update_state(y, y_prediction)

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
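
As a quick sanity check on the temperature behaviour described in the docstring (a small numerical sketch of my own, not from the original guide), dividing the logits by a larger temperature before the softmax gives a visibly softer distribution:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.1]])
# Plain softmax: peaked distribution, roughly [0.66, 0.24, 0.10]
print(tf.nn.softmax(logits, axis=1).numpy())
# Temperature 10: much softer distribution, roughly [0.37, 0.33, 0.30]
print(tf.nn.softmax(logits / 10.0, axis=1).numpy())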

Train and test datasets for the images:

batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

Now I can successfully train both the teacher and distiller models on the images:

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=1)
teacher.evaluate(x_test, y_test)

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=1)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
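
For the comparison that student_scratch was cloned for (not shown above; this is just the usual supervised compile/fit workflow as I understand it), I train the cloned student from scratch as a baseline:

# Train the cloned student from scratch as a baseline for comparison
student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
student_scratch.fit(x_train, y_train, epochs=1)
student_scratch.evaluate(x_test, y_test)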

Now, if I do the same for a text-embeddings dataset with shapes x_train == (19091, 512) and y_train == (19091,):

# Create the teacher
teacher = keras.Sequential(
    [
        keras.Input(shape=(512, )),
        layers.Flatten(),
        layers.Dense(13),
    ],
    name="teacher",
)

# Create the student
student = keras.Sequential(
    [
        keras.Input(shape=(512, )),
        layers.Flatten(),
        layers.Dense(13),
    ],
    name="student",
)

# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
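
The real text embeddings come from my own pipeline and are not shown here; as a purely hypothetical stand-in that matches the stated shapes (random data, labels in [0, 13) to match the 13-unit output layer), the dataset could be mocked like this:

# Hypothetical placeholder data matching the shapes stated above;
# the real embeddings and labels are produced elsewhere.
x_train = np.random.rand(19091, 512).astype("float32")
y_train = np.random.randint(0, 13, size=(19091,))
x_test = np.random.rand(2000, 512).astype("float32")   # test-set size is illustrative
y_test = np.random.randint(0, 13, size=(2000,))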

Now I can successfully train the teacher model:

# Train teacher as usual
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=1)
teacher.evaluate(x_test, y_test)

but when I train the distiller model on the text embeddings I get an error:

# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,
    temperature=10,
)

# Distill teacher to student
distiller.fit(x_train, y_train, epochs=1)

# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)

I'm getting the following error; specifically, it occurs on distiller.fit(x_train, y_train, epochs=1):

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/var/folders/yz/l3346p717ljdvzd6774hq1nc0000gn/T/ipykernel_30052/1479990457.py in <module>
     11 
     12 # Distill teacher to student
---> 13 distiller.fit(x_train, y_train, epochs=1)
     14 
     15 # Evaluate student on test dataset

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    794         max_queue_size=max_queue_size,
    795         workers=workers,
--> 796         use_multiprocessing=use_multiprocessing)
    797 
    798   def evaluate(self,

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_arrays_v1.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    623         steps=steps_per_epoch,
    624         validation_split=validation_split,
--> 625         shuffle=shuffle)
    626 
    627     if validation_data:

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2310     is_compile_called = False
   2311     if not self._is_compiled and self.optimizer:
-> 2312       self._compile_from_inputs(all_inputs, y_input, x, y)
   2313       is_compile_called = True
   2314 

/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target)
   2569         sample_weight_mode=self.sample_weight_mode,
   2570         run_eagerly=self.run_eagerly,
-> 2571         experimental_run_tf_function=self._experimental_run_tf_function)
   2572 
   2573   # TODO(omalleyt): Consider changing to a more descriptive function name.

TypeError: compile() got an unexpected keyword argument 'loss'

Has anyone run into this error before, or does anyone have an idea what might be wrong here? If more information is needed, please let me know.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
