How to use the Keras API to get custom objects

I implemented a custom constraint and would like Keras to handle its serialization and deserialization. However, the commented-out line below fails with the error shown above it:

import tensorflow as tf
from tensorflow.keras import backend

@tf.keras.utils.register_keras_serializable(package='mypackage', name='UnitL1Norm')
class UnitL1Norm(tf.keras.constraints.Constraint):

    def __init__(self, axis=0):
        self.axis = axis

    def __call__(self, w):
        return w / (
            backend.epsilon() + tf.reduce_sum(
                backend.abs(w), axis=self.axis, keepdims=True
            )
        )

    def get_config(self):
        return {'axis': self.axis}

# ValueError: Unknown constraint: UnitL1Norm. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
#a = tf.keras.constraints.get(dict(class_name='mypackage.UnitL1Norm', config=dict(axis=1)))
a = tf.keras.layers.Dense(3, kernel_constraint=UnitL1Norm(1))

I have read through this section but still have no idea how to achieve my goal. Could anyone give me a hint? Thanks!
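One thing that may be worth checking is the registry key itself. According to the `register_keras_serializable` documentation, the object is registered under the key `'package>name'`, so the lookup string would be `'mypackage>UnitL1Norm'` rather than `'mypackage.UnitL1Norm'`. The following is a minimal sketch under that assumption, not a definitive fix. It uses `tf.abs` instead of `backend.abs` purely for illustration, and asks `tf.keras.utils.get_registered_name` for the key instead of hard-coding it:

```python
import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package='mypackage', name='UnitL1Norm')
class UnitL1Norm(tf.keras.constraints.Constraint):
    """Rescales weights so their absolute values sum to roughly 1 along `axis`."""

    def __init__(self, axis=0):
        self.axis = axis

    def __call__(self, w):
        norm = tf.reduce_sum(tf.abs(w), axis=self.axis, keepdims=True)
        # epsilon is a function and must be called
        return w / (tf.keras.backend.epsilon() + norm)

    def get_config(self):
        return {'axis': self.axis}


# Ask Keras which key the class was registered under (package + '>' + name).
registered_name = tf.keras.utils.get_registered_name(UnitL1Norm)

# Look the constraint up by that key instead of 'mypackage.UnitL1Norm'.
constraint = tf.keras.constraints.get(
    {'class_name': registered_name, 'config': {'axis': 1}}
)

# Round trip through the serialize/deserialize helpers for constraints.
restored = tf.keras.constraints.deserialize(
    tf.keras.constraints.serialize(UnitL1Norm(axis=1))
)
```

If the lookup succeeds, `constraint` and `restored` should both be `UnitL1Norm` instances with `axis=1`, and no `custom_objects` argument should be needed because the decorator has already registered the class.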



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
