'Error when loading a custom layer from config
I have a custom preprocessing layer which basically takes the input and applies the preprocessing function of a pretrained network coming from tensorflow.keras.applications.
class PreprocessLayer(tf.keras.layers.Layer):
def __init__(self, preprocess, **kwargs):
super(PreprocessLayer, self).__init__(**kwargs)
self.preprocess = preprocess
def call(self, inputs):
return self.preprocess(inputs)
def get_config(self):
config = super().get_config()
config.update({
'preprocess': self.preprocess
})
return config
@classmethod
def from_config(cls, config):
return cls(**config)
I am using this layer as a part of Sequential model with data augmentation.
model = Sequential()
model.add(Input(shape=(img_size_x, img_size_y, num_channels)))
model.add(data_augmentation)
model.add(PreprocessLayer(preprocess=preprocessing_function))
model.add(base_model)
...
I am saving the network and load it back again using the following functions:
def save_model(model, model_save_path):
with open(model_save_path, 'w') as file:
file.write(model.to_json())
def get_model(model_load_path, weights_load_path, custom_objects=None):
file = open(model_load_path, 'r')
json = file.read()
file.close()
model = model_from_json(json, custom_objects)
model.load_weights(weights_load_path)
return model
I get the following error:
Traceback (most recent call last):
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3251, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-9-9d8a0d877fa9>", line 1, in <module>
new_model = get_model('/home/orkhan/Projects/FRS/Models/Style/Resnet/best_model/style_resnet',
File "<ipython-input-2-96a830f7807a>", line 19, in get_model
model = model_from_json(json, custom_objects)
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/saving/model_config.py", line 104, in model_from_json
return deserialize(config, custom_objects=custom_objects)
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/layers/serialization.py", line 207, in deserialize
return generic_utils.deserialize_keras_object(
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/utils/generic_utils.py", line 678, in deserialize_keras_object
deserialized_obj = cls.from_config(
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/engine/sequential.py", line 438, in from_config
model.add(layer)
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
result = method(self, *args, **kwargs)
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/orkhan/Projects/FRS/venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 699, in wrapper
raise e.ag_error_metadata.to_exception(e)
TypeError: Exception encountered when calling layer "preprocess_layer" (type PreprocessLayer).
in user code:
File "<ipython-input-6-c9ecaaa8ccf6>", line 7, in call *
return self.preprocess(inputs)
TypeError: 'str' object is not callable
Call arguments received:
• inputs=tf.Tensor(shape=(None, None, None, 3), dtype=float32)
Apparently the preprocessing function which is passed to the PreprocessLayer class is treated as 'str' not as function. How can I fix that?
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
