How to access the model inside a custom Keras loss function?
def custom_correlation_loss(input_data, model):
    def custom_loss(y_true, y_pred):
        print(input_data)
        model_prediction = model.predict(input_data)
        print(model_prediction)
        mse_loss = keras.losses.MSE(y_true, y_pred)
        return mse_loss
    return custom_loss
# Some model architecture "model"
model.load_weights('/content/drive/MyDrive/v2a_with_only_sex_combined_10k_model_survey_to_health_outcomes_LSTM.h5', by_name=True, skip_mismatch=True)
model.compile(optimizer=adam, loss=custom_correlation_loss(lambda_x, model), run_eagerly=True)
I want to compute some information with the model inside a custom loss function. When I try to access the model there, I get "Method requires being in cross-replica context, use get_replica_context().merge_call()".
Can someone tell me what the issue is, or how I can access the model inside a custom loss function so that the model I access reflects the weights updated after every batch iteration?
RuntimeError: in user code:
    File "<ipython-input-64-7b5fc945c2b5>", line 102, in custom_loss *
        x = callback.on_train_batch_end(batch=0, logs={})
    File "<ipython-input-42-3132770117f2>", line 42, in on_train_batch_end *
        y_pred_batch = self.model.predict(x_batch_data.numpy())
    File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler **
        raise e.with_traceback(filtered_tb) from None
    RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()
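One possible workaround, sketched here under assumptions rather than taken from an accepted answer: in both the snippet and the traceback, the failing call is `model.predict(...)`, which runs its own batch inference loop and is probably not meant to be invoked from inside a training step. Calling the model object directly, `model(x, training=False)`, runs a plain forward pass with whatever weights the model currently holds, so it reflects the updates made after each batch. The toy model, the random data, the `probe` inputs, the name `make_loss`, and the `0.1 * penalty` term below are all hypothetical stand-ins for illustration.

```python
import numpy as np
import tensorflow as tf


def make_loss(model, extra_inputs):
    """Build a loss that also evaluates `model` on fixed extra inputs."""
    # Convert once, outside the loss, so the loss body only sees tensors.
    extra_inputs = tf.convert_to_tensor(extra_inputs, dtype=tf.float32)

    def custom_loss(y_true, y_pred):
        # Direct call instead of model.predict(): a forward pass that uses
        # the model's current weights at the time the loss is evaluated.
        extra_pred = model(extra_inputs, training=False)
        # Per-sample mean squared error between targets and predictions.
        mse = tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
        # Hypothetical extra term computed from the model's own output.
        penalty = tf.reduce_mean(tf.square(extra_pred))
        return mse + 0.1 * penalty

    return custom_loss


# Hypothetical toy model and data, only to make the sketch runnable.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
probe = np.random.rand(8, 4).astype("float32")

model.compile(optimizer="adam", loss=make_loss(model, probe))
history = model.fit(x, y, batch_size=8, epochs=1, verbose=0)
```

Because the extra forward pass happens inside the loss, gradients also flow through `penalty`. If the extra prediction should only be observed and not trained on, wrapping it in `tf.stop_gradient(...)` is one option.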
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
