TypeError: the first argument must be callable when calling TensorFlow optimizer `apply_gradients`
I hope someone can help me resolve this issue, which has been driving me crazy for days. I am building something loosely inspired by this Keras example. I am trying to calculate the gradients of a network manually, but I can't figure out what I am doing wrong. Here is the model definition:
```python
from tensorflow import keras
from tensorflow.keras import layers

# state_dim, l1_dim, l2_dim, num_actions and learning_rate are defined elsewhere
inputs = layers.Input(shape=(state_dim,))
layer1 = layers.Dense(l1_dim, activation="relu")(inputs)
layer2 = layers.Dense(l2_dim, activation="relu")(layer1)
action = layers.Dense(num_actions, activation="softmax")(layer2)  # policy head
critic = layers.Dense(1, activation=None)(layer2)                 # value head
model = keras.Model(inputs=inputs, outputs=[critic, action])
# model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate))
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
```
Then I have my training loop, which, given a transition (state, action, reward, terminal, state_), does:
```python
import tensorflow as tf
import tensorflow_probability as tfp

# Runs inside a method of the agent class, hence self.gamma (the discount factor).
state = tf.convert_to_tensor([state], dtype=tf.float32)
state_ = tf.convert_to_tensor([state_], dtype=tf.float32)
reward = tf.convert_to_tensor(reward, dtype=tf.float32)  # not fed to NN
with tf.GradientTape(persistent=True) as tape:
    state_value, probs = model(state)
    state_value_, _ = model(state_)
    state_value = tf.squeeze(state_value)
    state_value_ = tf.squeeze(state_value_)
    action_probs = tfp.distributions.Categorical(probs=probs)
    log_prob = action_probs.log_prob(action)
    delta = reward + self.gamma * state_value_ * (1 - int(terminal)) - state_value
    actor_loss = -log_prob * delta
    critic_loss = delta ** 2
    total_loss = actor_loss + critic_loss
gradient = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradient, model.trainable_variables))
```
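The loss in this loop is the one-step temporal-difference (TD) error used by actor-critic methods: delta = r + gamma * V(s') * (1 - done) - V(s). As a plain-Python sketch of that arithmetic on scalars instead of tensors (the helper `actor_critic_losses` and all numbers are hypothetical, for illustration only):

```python
import math


def actor_critic_losses(reward, gamma, value, next_value, terminal, prob_of_action):
    """Hypothetical helper mirroring the loss lines of the training loop."""
    # Bootstrap from the next state's value only if the episode did not end.
    delta = reward + gamma * next_value * (1 - int(terminal)) - value
    log_prob = math.log(prob_of_action)
    actor_loss = -log_prob * delta  # policy-gradient term, weighted by the TD error
    critic_loss = delta ** 2        # squared TD error for the value head
    return delta, actor_loss, critic_loss, actor_loss + critic_loss


# Non-terminal step: delta = 1.0 + 0.99 * 0.8 - 0.5 = 1.292
delta, actor_loss, critic_loss, total_loss = actor_critic_losses(
    reward=1.0, gamma=0.99, value=0.5, next_value=0.8,
    terminal=False, prob_of_action=0.25,
)
```

On a terminal step the bootstrap term vanishes, so delta reduces to r - V(s).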
However, on the last line, when calling `optimizer.apply_gradients`, I get the following error:
```
Traceback (most recent call last):
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients
    self._create_all_weights(var_list)
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 829, in _create_all_weights
    self._create_hypers()
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 977, in _create_hypers
    self._hyper[name] = self.add_weight(
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1192, in add_weight
    variable = self._add_variable_with_custom_getter(
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/tensorflow/python/training/tracking/base.py", line 816, in _add_variable_with_custom_getter
    new_variable = getter(
  File "/Users/maccheroni/.virtualenvs/rl_gym/lib/python3.9/site-packages/keras/engine/base_layer_utils.py", line 106, in make_variable
    init_val = functools.partial(initializer, shape, dtype=dtype)
TypeError: the first argument must be callable
```
I really don't understand why, because I have read many tutorials and followed many examples, and they all seem to use this function in the same way.
Solution 1:[1]
I had the same error and found the cause. In my case, the optimizer was initialized with

```python
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
```

using a variable `learning_rate` that was `None`. Initializing it with a number, or simply with

```python
optimizer = keras.optimizers.Adam()
```

solved the issue.
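Why `None` produces this particular message: the last frame of the traceback builds the initial value with `functools.partial(initializer, shape, dtype=dtype)`, and `functools.partial` rejects a first argument that is not callable. Assuming the unset learning rate reaches that frame as `initializer` (an inference from the traceback, not a quote of the Keras source), the standard library alone reproduces the error:

```python
import functools

# Assumption: the optimizer's hyperparameter value (the learning rate) is what
# arrives as `initializer` in the last traceback frame.
learning_rate = None  # the bug: the variable was never set to a number
initializer = learning_rate

message = None
try:
    # Same call shape as the last frame: partial(initializer, shape, dtype=...)
    functools.partial(initializer, [], dtype="float32")
except TypeError as exc:
    message = str(exc)

print(message)  # the first argument must be callable
```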
In your case, it is not clear what `learning_rate` is, but it is worth checking whether it is `None`.
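One way to catch this before it reaches Keras is to validate the value where the optimizer is built. A minimal sketch with a hypothetical helper `resolve_learning_rate`, assuming you want to fall back to 0.001 (the Keras default for Adam) when the value is missing. It only accepts plain numbers, so a Keras learning-rate schedule object would be rejected:

```python
DEFAULT_ADAM_LR = 0.001  # default learning_rate of keras.optimizers.Adam


def resolve_learning_rate(value, default=DEFAULT_ADAM_LR):
    """Hypothetical helper: return a usable float or fail early with a clear error."""
    if value is None:
        # Same effect as calling keras.optimizers.Adam() with no argument.
        return default
    # bool is a subclass of int, so exclude it explicitly.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError(f"learning_rate must be a number, got {type(value).__name__}")
    if value <= 0:
        raise ValueError(f"learning_rate must be positive, got {value}")
    return float(value)


# Usage (requires TensorFlow, shown as a comment):
# optimizer = keras.optimizers.Adam(learning_rate=resolve_learning_rate(learning_rate))
```

Raising instead of falling back is an equally reasonable design if a missing learning rate should be treated as a configuration error.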
Solution 2:[2]
First, you are using an "offline" version of the Java buildpack. This means the buildpack can only use dependencies that are bundled with it. You're using version 4.46, whose bundled dependencies are listed in its release notes:
https://github.com/cloudfoundry/java-buildpack/releases/tag/v4.46
These include Java 1.8.0_312, 11.0.13_8, and 17.0.1_12.
Second, you are requesting Java 15. The buildpack fails because Java 15 is not among its bundled dependencies, so it has no way to install it.
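The reasoning is a simple membership check: an offline buildpack can only install a Java version it ships with. A small Python sketch using the version strings listed in the release notes (the helper `major_version` is hypothetical and not part of the buildpack):

```python
# Java versions bundled with the offline Java buildpack v4.46, per the release notes.
BUNDLED = ["1.8.0_312", "11.0.13_8", "17.0.1_12"]


def major_version(version):
    """Return the Java major version, e.g. '1.8.0_312' -> 8, '11.0.13_8' -> 11."""
    parts = version.split("_")[0].split(".")
    first = int(parts[0])
    # Legacy numbering: Java 8 and earlier are reported as '1.<major>'.
    return int(parts[1]) if first == 1 else first


bundled_majors = {major_version(v) for v in BUNDLED}
print(sorted(bundled_majors))  # [8, 11, 17]
print(15 in bundled_majors)    # False: Java 15 is not available offline
```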
The default version of Java that the buildpack uses is Java 8. Since it's picking something else here, that must come from user configuration. Check your environment variables, e.g. with `cf env`, and look for ones starting with `JBP_CONFIG_`. There is likely one set that configures the buildpack to use Java 15.
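If you have the app's environment as a Python dict (for example, transcribed by hand from the `cf env` output; this sketch does no parsing), filtering by the prefix isolates the buildpack settings. The helper `jbp_config_vars` and the sample values are hypothetical, for illustration only:

```python
def jbp_config_vars(env):
    """Return only the entries that configure the Java buildpack (JBP_CONFIG_*)."""
    return {key: value for key, value in env.items() if key.startswith("JBP_CONFIG_")}


# Hypothetical environment; the JRE override value is illustrative only.
app_env = {
    "JBP_CONFIG_OPEN_JDK_JRE": "{ jre: { version: 15.+ } }",
    "SPRING_PROFILES_ACTIVE": "cloud",
}
print(jbp_config_vars(app_env))  # {'JBP_CONFIG_OPEN_JDK_JRE': '{ jre: { version: 15.+ } }'}
```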
See the Java buildpack docs for more details on how it can be configured.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Lukas Hebing |
| Solution 2 | Daniel Mikusa |
