cosine_proximity loss function problem when using model.fit
I'm trying to run this RNN model with the cosine_proximity loss function. I'm coding in Google Colab. Can someone please help me figure out the problem? Here is the source code of the RNN model:
```python
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Dropout, Dense

# X_train and y_train are loaded earlier in the notebook (not shown)
model = Sequential()
model.add(LSTM(units=512, input_shape=X_train.shape[1:], activation='relu', return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=128, activation='relu', return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(units=64, activation='relu', return_sequences=True))
model.add(Dropout(0.2))
model.add(Dense(units=10, activation='relu'))
model.add(Dropout(0.2))
model.compile(loss="cosine_proximity", optimizer='sgd', metrics=['accuracy'])
print(model.summary())
model.fit(X_train, y_train, epochs=1, verbose=1)
```
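Keras does not resolve the loss string when `model.compile` runs; the traceback that follows shows `model.fit` looking it up by name through `keras.losses.get` on the first training step. For reference, here is a minimal pure-Python sketch of what `cosine_proximity` computed in older standalone Keras. This is an assumption based on the Keras 2.2-era definition (L2-normalize target and prediction along the last axis, then return the negated dot product); the helper names `l2_normalize` and `cosine_proximity` below are illustrative only, not a Keras API.

```python
import math


def l2_normalize(v, eps=1e-12):
    # Illustrative helper: divide by the L2 norm, with an epsilon floor on the
    # squared norm so an all-zero vector does not cause division by zero.
    norm = math.sqrt(max(sum(x * x for x in v), eps))
    return [x / norm for x in v]


def cosine_proximity(y_true, y_pred):
    # Illustrative helper: negated cosine similarity of two 1-D vectors.
    # -1.0 means same direction, 0.0 orthogonal, 1.0 opposite direction.
    t = l2_normalize(y_true)
    p = l2_normalize(y_pred)
    return -sum(a * b for a, b in zip(t, p))


print(cosine_proximity([1.0, 0.0], [2.0, 0.0]))
print(cosine_proximity([3.0, 4.0], [4.0, 3.0]))
```

Because both vectors are normalized, only direction matters: minimizing this loss pushes predictions to point the same way as the targets regardless of their magnitude.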
And this is what I get when I run the cell:
```
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 lstm_27 (LSTM)              (None, 523, 512)          1052672
 lstm_28 (LSTM)              (None, 523, 128)          328192
 lstm_29 (LSTM)              (None, 523, 64)           49408
=================================================================
Total params: 1,430,272
Trainable params: 1,430,272
Non-trainable params: 0
_________________________________________________________________
None
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-44-fc8e0b2a4cd4> in <module>()
     13 print(model.summary())
     14
---> 15 model.fit(X_train, y_train, epochs=1, verbose=1)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in autograph_handler(*args, **kwargs)
   1145 except Exception as e: # pylint:disable=broad-except
   1146 if hasattr(e, "ag_error_metadata"):
-> 1147 raise e.ag_error_metadata.to_exception(e)
   1148 else:
   1149 raise

ValueError: in user code:

    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1021, in train_function *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1010, in step_function **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 1000, in run_step **
        outputs = model.train_step(data)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 860, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/training.py", line 919, in compute_loss
        y, y_pred, sample_weight, regularization_losses=self.losses)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 184, in __call__
        self.build(y_pred)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 133, in build
        self._losses = tf.nest.map_structure(self._get_loss_object, self._losses)
    File "/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py", line 272, in _get_loss_object
        loss = losses_mod.get(loss)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 2369, in get
        return deserialize(identifier)
    File "/usr/local/lib/python3.7/dist-packages/keras/losses.py", line 2328, in deserialize
        printable_module_name='loss function')
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py", line 710, in deserialize_keras_object
        f'Unknown {printable_module_name}: {object_name}. Please ensure '

ValueError: Unknown loss function: cosine_proximity. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
```
Any help fixing this problem would be appreciated.
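The last line of the traceback says Keras has no loss registered under the name `cosine_proximity`. As far as I know, TF 2.x Keras exposes the same quantity as `tf.keras.losses.CosineSimilarity` (the string alias `"cosine_similarity"` should also resolve), which returns the negated cosine along the last axis. The sketch below is a minimal, hedged example rather than a definitive fix: the random `X_train`/`y_train` arrays and the smaller layer sizes are hypothetical stand-ins for the asker's data and model, chosen only so it runs quickly.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for the real X_train / y_train:
# 32 sequences, 20 timesteps, 1 feature per timestep.
rng = np.random.default_rng(0)
X_train = rng.random((32, 20, 1), dtype=np.float32)
# The last LSTM uses return_sequences=True, so Dense runs at every timestep
# and the targets must be shaped (samples, timesteps, 10).
y_train = rng.random((32, 20, 10), dtype=np.float32)

model = tf.keras.Sequential([
    tf.keras.Input(shape=X_train.shape[1:]),
    tf.keras.layers.LSTM(64, activation="relu", return_sequences=True),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.LSTM(32, activation="relu", return_sequences=True),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation="relu"),
])

# Replace the unregistered "cosine_proximity" name with the TF 2.x loss class.
model.compile(loss=tf.keras.losses.CosineSimilarity(axis=-1), optimizer="sgd")

history = model.fit(X_train, y_train, epochs=1, verbose=0)
print(history.history["loss"])
```

This sketch also leaves out `metrics=['accuracy']`, which is not meaningful for continuous targets (`tf.keras.metrics.CosineSimilarity` is one alternative), and the `Dropout` after the output layer, which would randomly zero predictions during training.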
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
