'Customize model.fit() in Tensorflow Keras - TypeError: TypeError: compile() got an unexpected keyword argument 'loss'
I am trying to implement my own weight training algorithm for Knowledge Distillation by modifying the model. fit() in Keras. I have read link
It works perfectly fine when input data is Images but won't work when input data is text embeddings P.S I have removed all network details to show you how it works with images but not text embeddings The model that I have created are for images:
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Flatten(),
layers.Dense(10),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Flatten(),
layers.Dense(10),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
The custom Distiller() class, overrides the Model methods train_step, test_step, and compile().
class Distiller(keras.Model):
def __init__(self, student, teacher):
super(Distiller, self).__init__()
self.teacher = teacher
self.student = student
def compile(
self,
optimizer,
metrics,
student_loss_fn,
distillation_loss_fn,
alpha=0.1,
temperature=3,
):
""" Configure the distiller.
Args:
optimizer: Keras optimizer for the student weights
metrics: Keras metrics for evaluation
student_loss_fn: Loss function of difference between student
predictions and ground-truth
distillation_loss_fn: Loss function of difference between soft
student predictions and soft teacher predictions
alpha: weight to student_loss_fn and 1-alpha to distillation_loss_fn
temperature: Temperature for softening probability distributions.
Larger temperature gives softer distributions.
"""
super(Distiller, self).compile(optimizer=optimizer, metrics=metrics)
self.student_loss_fn = student_loss_fn
self.distillation_loss_fn = distillation_loss_fn
self.alpha = alpha
self.temperature = temperature
def train_step(self, data):
# Unpack data
x, y = data
# Forward pass of teacher
teacher_predictions = self.teacher(x, training=False)
with tf.GradientTape() as tape:
# Forward pass of student
student_predictions = self.student(x, training=True)
# Compute losses
student_loss = self.student_loss_fn(y, student_predictions)
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
# Compute gradients
trainable_vars = self.student.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
# Update weights
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
# Update the metrics configured in `compile()`.
self.compiled_metrics.update_state(y, student_predictions)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update(
{"student_loss": student_loss, "distillation_loss": distillation_loss}
)
return results
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_prediction = self.student(x, training=False)
# Calculate the loss
student_loss = self.student_loss_fn(y, y_prediction)
# Update the metrics.
self.compiled_metrics.update_state(y, y_prediction)
# Return a dict of performance
results = {m.name: m.result() for m in self.metrics}
results.update({"student_loss": student_loss})
return results
train and test dataset for images.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalize data
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
Now I can successfully train both teacher and distiller models for images
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=1)
teacher.evaluate(x_test, y_test)
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=1)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
Now if I do the same for text embeddings data set with shapes x_train == (19091, 512) y_train == (19091,)
# Create the teacher
teacher = keras.Sequential(
[
keras.Input(shape=(512, )),
layers.Flatten(),
layers.Dense(13),
],
name="teacher",
)
# Create the student
student = keras.Sequential(
[
keras.Input(shape=(512, )),
layers.Flatten(),
layers.Dense(13),
],
name="student",
)
# Clone student for later comparison
student_scratch = keras.models.clone_model(student)
Now I can successfully train teacher model
# Train teacher as usual
teacher.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train and evaluate teacher on data.
teacher.fit(x_train, y_train, epochs=1)
teacher.evaluate(x_test, y_test)
but when I train distiller model for text embeddings I'm getting the error
# Initialize and compile distiller
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
distillation_loss_fn=keras.losses.KLDivergence(),
alpha=0.1,
temperature=10,
)
# Distill teacher to student
distiller.fit(x_train, y_train, epochs=1)
# Evaluate student on test dataset
distiller.evaluate(x_test, y_test)
I'm getting the following error specifically it occurs on distiller.fit(x_train, y_train, epochs=1)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/var/folders/yz/l3346p717ljdvzd6774hq1nc0000gn/T/ipykernel_30052/1479990457.py in <module>
11
12 # Distill teacher to student
---> 13 distiller.fit(x_train, y_train, epochs=1)
14
15 # Evaluate student on test dataset
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
794 max_queue_size=max_queue_size,
795 workers=workers,
--> 796 use_multiprocessing=use_multiprocessing)
797
798 def evaluate(self,
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_arrays_v1.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
623 steps=steps_per_epoch,
624 validation_split=validation_split,
--> 625 shuffle=shuffle)
626
627 if validation_data:
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
2310 is_compile_called = False
2311 if not self._is_compiled and self.optimizer:
-> 2312 self._compile_from_inputs(all_inputs, y_input, x, y)
2313 is_compile_called = True
2314
/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/keras/engine/training_v1.py in _compile_from_inputs(self, all_inputs, target, orig_inputs, orig_target)
2569 sample_weight_mode=self.sample_weight_mode,
2570 run_eagerly=self.run_eagerly,
-> 2571 experimental_run_tf_function=self._experimental_run_tf_function)
2572
2573 # TODO(omalleyt): Consider changing to a more descriptive function name.
TypeError: compile() got an unexpected keyword argument 'loss'
Have anyone run into the same error before or have any idea what might be wrong here? If more information is needed to help me, please let me know.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
