How to access the model inside a custom Keras loss function?

from tensorflow import keras

def custom_correlation_loss(input_data, model):
  def custom_loss(y_true, y_pred):
    print(input_data)
    # Run the model on input_data from inside the loss function
    model_prediction = model.predict(input_data)
    print(model_prediction)
    mse_loss = keras.losses.MSE(y_true, y_pred)
    return mse_loss
  return custom_loss

# Some model architecture "model"

model.load_weights('/content/drive/MyDrive/v2a_with_only_sex_combined_10k_model_survey_to_health_outcomes_LSTM.h5', by_name=True, skip_mismatch=True)
model.compile(optimizer=adam, loss=custom_correlation_loss(lambda_x, model), run_eagerly=True)

I want to compute some information with the model inside the custom loss function. When I try to access the model, I get "Method requires being in cross-replica context, use get_replica_context().merge_call()".

Can someone explain what the issue is, or how I can access the model inside a custom loss function so that the model I access is the one updated after every batch iteration?

RuntimeError: in user code:

File "<ipython-input-64-7b5fc945c2b5>", line 102, in custom_loss  *
    x = callback.on_train_batch_end(batch=0, logs={})
File "<ipython-input-42-3132770117f2>", line 42, in on_train_batch_end  *
    y_pred_batch = self.model.predict(x_batch_data.numpy())
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler  **
    raise e.with_traceback(filtered_tb) from None

RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()
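A minimal sketch of one commonly suggested workaround (not taken from the original post): call the model directly as model(x, training=False) instead of model.predict(x). predict() starts its own prediction loop, and the traceback suggests it cannot be launched from inside the training step where the loss is evaluated. A direct call is an ordinary forward pass that can be traced inside that step and always uses the model's current weights, which change after every optimizer update. The toy model, the random data, and the tf.print of the prediction mean are hypothetical stand-ins for the real architecture and for whatever "information" the loss should compute.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def custom_correlation_loss(input_data, model):
    # Convert the fixed extra input once, outside the per-batch loss.
    input_tensor = tf.convert_to_tensor(input_data, dtype=tf.float32)

    def custom_loss(y_true, y_pred):
        # Direct forward pass instead of model.predict(): traceable inside
        # the training step and uses the model's current weights.
        model_prediction = model(input_tensor, training=False)
        # Hypothetical placeholder for the real computation: log the mean.
        tf.print("mean prediction:", tf.reduce_mean(model_prediction))
        return keras.losses.mean_squared_error(y_true, y_pred)

    return custom_loss


# Hypothetical toy model and data, for illustration only.
x_train = np.random.rand(32, 4).astype("float32")
y_train = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss=custom_correlation_loss(x_train[:8], model))
history = model.fit(x_train, y_train, epochs=1, batch_size=8, verbose=0)
```

If the extra prediction is folded into the returned loss, gradients will also flow through that forward pass; wrap it in tf.stop_gradient if it should only contribute information rather than a training signal.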


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
