ValueError: Unknown metric function: top_2_accuracy. Please ensure this object is passed to the `custom_objects` argument
I'm working on a CNN classification project and used top-2 accuracy (`top_k_categorical_accuracy` with `k=2`) as a metric during training. The function in the model notebook is:
```python
from tensorflow.keras.metrics import top_k_categorical_accuracy

def top_2_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=2)
```
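To make concrete what this metric measures, here is a plain NumPy re-implementation for illustration only. It is an assumption-based sketch, not Keras's own code: `top_k_accuracy` is a hypothetical helper name, and ties between scores may be counted differently than in Keras's `in_top_k`-based version.

```python
import numpy as np

def top_k_accuracy(y_true, y_pred, k=2):
    """Fraction of samples whose true class is among the k highest scores.

    Hypothetical helper for illustration. y_true is one-hot with shape
    (n_samples, n_classes); y_pred holds scores of the same shape.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    true_class = np.argmax(y_true, axis=1)               # index of the 1 in each one-hot row
    top_k = np.argsort(y_pred, axis=1)[:, -k:]           # indices of the k largest scores per row
    hits = np.any(top_k == true_class[:, None], axis=1)  # is the true class among them?
    return float(np.mean(hits))

y_true = [[0, 0, 1],
          [1, 0, 0],
          [0, 1, 0]]
y_pred = [[0.1, 0.3, 0.6],   # true class 2 ranked 1st: hit
          [0.2, 0.5, 0.3],   # true class 0 ranked 3rd: miss
          [0.5, 0.3, 0.2]]   # true class 1 ranked 2nd: hit
print(top_k_accuracy(y_true, y_pred, k=2))
```

With `k=2` two of the three samples count as correct, whereas plain accuracy (`k=1`) would count only the first one.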
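Then I passed it to `model.compile`:
```python
from tensorflow.keras.optimizers import Adam

model.compile(optimizer=Adam(learning_rate=4e-3),  # `lr` is the deprecated name
              loss='categorical_crossentropy',
              metrics=['accuracy', top_2_accuracy])
```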
Now I need to load the model in an application, so I tried this code:
```python
model = tf.keras.models.load_model('model.h5')
```
but loading fails with this error:
```
Unknown metric function: top_2_accuracy. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
```
After searching online, I also tried this:
```python
model = tf.keras.models.load_model('model.h5', custom_objects={'top_k_categorical_accuracy(y_true, y_pred, k=2)':top_2_accuracy})
```
but that raised a different error on the same line:
```
NameError: name 'top_2_accuracy' is not defined
```
How can I solve this?
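A minimal sketch of the likely fix, under stated assumptions. The `NameError` suggests that `top_2_accuracy` is defined only in the training notebook, not in the application script that calls `load_model`. The error message also names the metric as `top_2_accuracy`, which suggests the `custom_objects` key should be that function name rather than the expression it evaluates. `load_trained_model` is a hypothetical helper name, and the path `'model.h5'` is taken from the question.

```python
import tensorflow as tf
from tensorflow.keras.metrics import top_k_categorical_accuracy

# The metric must also be defined (or imported) in the application script;
# load_model cannot recreate a custom Python function from the .h5 file alone.
def top_2_accuracy(y_true, y_pred):
    return top_k_categorical_accuracy(y_true, y_pred, k=2)

def load_trained_model(path='model.h5'):
    # Hypothetical helper. The key is the name reported in the error
    # ('top_2_accuracy'), mapped to the function object defined above.
    return tf.keras.models.load_model(
        path, custom_objects={'top_2_accuracy': top_2_accuracy})

# Alternative if the model is only used for inference: compile=False skips
# restoring the loss and metrics, so the custom metric is not needed:
# model = tf.keras.models.load_model('model.h5', compile=False)
```

The `compile=False` route is simpler for a prediction-only application, but the model then has to be recompiled before calling `evaluate` or continuing training.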
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
