'Error when loading a custom layer from config

I have a custom preprocessing layer which basically takes the input and applies the preprocessing function of a pretrained network coming from tensorflow.keras.applications.

class PreprocessLayer(tf.keras.layers.Layer):
    def __init__(self, preprocess, **kwargs):
        super(PreprocessLayer, self).__init__(**kwargs)
        self.preprocess = preprocess

    def call(self, inputs):
        return self.preprocess(inputs)

    def get_config(self):
        config = super().get_config()
        config.update({
            'preprocess': self.preprocess
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

I am using this layer as a part of Sequential model with data augmentation.

model = Sequential()
model.add(Input(shape=(img_size_x, img_size_y, num_channels)))
model.add(data_augmentation)
model.add(PreprocessLayer(preprocess=preprocessing_function))
model.add(base_model)
...

I am saving the network and load it back again using the following functions:

def save_model(model, model_save_path):
    with open(model_save_path, 'w') as file:
        file.write(model.to_json())

def get_model(model_load_path, weights_load_path, custom_objects=None):
    file = open(model_load_path, 'r')
    json = file.read()
    file.close()
    model = model_from_json(json, custom_objects)
    model.load_weights(weights_load_path)
    return model

I get the following error:

Traceback (most recent call last):
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-9d8a0d877fa9>", line 1, in <module>
    new_model = get_model('/home/orkhan/Projects/FRS/Models/Style/Resnet/best_model/style_resnet',
  File "<ipython-input-2-96a830f7807a>", line 19, in get_model
    model = model_from_json(json, custom_objects)
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/saving/model_config.py", line 104, in model_from_json
    return deserialize(config, custom_objects=custom_objects)
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/layers/serialization.py", line 207, in deserialize
    return generic_utils.deserialize_keras_object(
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/utils/generic_utils.py", line 678, in deserialize_keras_object
    deserialized_obj = cls.from_config(
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/engine/sequential.py", line 438, in from_config
    model.add(layer)
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 699, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: Exception encountered when calling layer "preprocess_layer" (type PreprocessLayer).
in user code:
    File "<ipython-input-6-c9ecaaa8ccf6>", line 7, in call  *
        return self.preprocess(inputs)
    TypeError: 'str' object is not callable
Call arguments received:
  • inputs=tf.Tensor(shape=(None, None, None, 3), dtype=float32)

Apparently the preprocessing function which is passed to the PreprocessLayer class is treated as 'str' not as function. How can I fix that?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source