How to use the Keras API to get custom objects
I implemented a custom constraint and would like Keras to handle its serialization and deserialization. However, the commented-out line below fails:
```python
import tensorflow as tf
from tensorflow.keras import backend


@tf.keras.utils.register_keras_serializable(package='mypackage', name='UnitL1Norm')
class UnitL1Norm(tf.keras.constraints.Constraint):
    def __init__(self, axis=0):
        self.axis = axis

    def __call__(self, w):
        # Divide by the L1 norm along `axis`. backend.epsilon is a function and
        # must be called; tf.abs replaces backend.abs, which not every Keras version provides.
        return w / (
            backend.epsilon() + tf.reduce_sum(
                tf.abs(w), axis=self.axis, keepdims=True
            )
        )

    def get_config(self):
        return {'axis': self.axis}


# Uncommenting the next line raises:
# ValueError: Unknown constraint: UnitL1Norm. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
# a = tf.keras.constraints.get(dict(class_name='mypackage.UnitL1Norm', config=dict(axis=1)))
a = tf.keras.layers.Dense(3, kernel_constraint=UnitL1Norm(1))
```
I have read through that section of the guide but still have no idea how to achieve my goal. Could anyone give me a hint? Thanks!
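A minimal sketch of one likely fix, assuming TensorFlow 2.x `tf.keras` (not checked against every release): `register_keras_serializable` stores the class under the key `'<package>><name>'`, so the lookup name here should be `'mypackage>UnitL1Norm'` rather than `'mypackage.UnitL1Norm'`. `tf.keras.utils.get_registered_name` returns that key, and a `tf.keras.constraints.serialize` / `get` round trip avoids hard-coding it. Run it as a standalone script, since some Keras versions refuse to register the same name twice in one session.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend


@tf.keras.utils.register_keras_serializable(package='mypackage', name='UnitL1Norm')
class UnitL1Norm(tf.keras.constraints.Constraint):
    """Rescales weights so their absolute values sum to ~1 along `axis`."""

    def __init__(self, axis=0):
        self.axis = axis

    def __call__(self, w):
        return w / (backend.epsilon() + tf.reduce_sum(tf.abs(w), axis=self.axis, keepdims=True))

    def get_config(self):
        return {'axis': self.axis}


# The registry key joins package and name with '>'.
key = tf.keras.utils.get_registered_name(UnitL1Norm)
print(key)  # mypackage>UnitL1Norm

# Option 1: pass the registered key as class_name.
a = tf.keras.constraints.get({'class_name': 'mypackage>UnitL1Norm', 'config': {'axis': 1}})

# Option 2: let Keras build the config dict, then rebuild the object from it.
b = tf.keras.constraints.get(tf.keras.constraints.serialize(UnitL1Norm(1)))

# The restored constraint normalizes each row by its L1 norm.
w = tf.constant([[1.0, -3.0], [2.0, 2.0]])
print(np.round(a(w).numpy(), 3))
```

Option 2 is probably the more robust choice: the dictionary comes from Keras itself, and its exact fields may differ between versions (newer releases add keys such as `module` and `registered_name`).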
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
