model.fit triggers tf.function retracing
I'm not entirely sure what code will be useful for this, so bear with me if there's a lot of it.
I've been working on training a whole population of agents within a single generation, as part of my attempts to develop a genetic algorithm over deep reinforcement learning agents. As part of that, I want to display every agent taking its action on each step.
So I have my agents in a list. For each agent, I run that agent's own model to get the action it should perform, the agent takes that action, and then it trains on its short-term memory to try and improve itself just a bit.
For that short-term memory training I use model.fit(). But when I do, I get the following message:
WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_train_function.<locals>.train_function at 0x000001F1ED719A60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
1/1 [==============================] - 0s 231ms/step - loss: 1.5837e-05 - root_mean_squared_error: 0.0040
Print-statement debugging tells me that the model.fit call is definitely where the warning comes from, but I don't understand what the issue would be. When I process a single agent at a time (running one agent until it dies) instead of stepping all the agents at once, the model.fit works just fine.
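For context, I do think I understand what "retracing" itself means. Here's a tiny, self-contained illustration of cause (3) from the warning (it has nothing to do with my models; the function here is made up just to show the general tf.function behaviour the warning describes):

import tensorflow as tf

@tf.function
def double(x):
    # This print only runs while the function is being (re)traced
    print("Tracing with", x)
    return x * 2

double(1)               # traces for the Python int 1
double(2)               # traces again: a different Python value
double(tf.constant(3))  # traces once for an int32 scalar tensor...
double(tf.constant(4))  # ...and this call reuses that trace

That part makes sense to me; what I can't see is which of those three cases my model.fit calls are hitting.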
Now, as promised, here's the code. I don't know what's useful, so I'm giving a lot of it:
The model:
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import RootMeanSquaredError

class QNet():
    @staticmethod
    def linear_QNet(model_path="", input_size=11, hidden_sizes=[256], output_size=3, learning_rate=0.001):
        '''
        Build the model, given an input size, list of hidden layer sizes,
        output size, and learning rate.
        Either load a trained model or create a new one.
        '''
        # If you want to load a trained model...
        if model_path:
            model = load_model(model_path)
        # If you want to build a new model
        else:
            model = Sequential()
            model.add(Input(shape=(input_size,)))
            for layer_size in hidden_sizes:
                model.add(Dense(units=layer_size, activation="relu"))
            model.add(Dense(units=output_size))
            model.compile(optimizer=Adam(learning_rate=learning_rate),
                          loss="mean_squared_error",
                          metrics=[RootMeanSquaredError()])
        return model
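For reference, each agent gets its own network from this, roughly like the following (simplified; population_size and the loop are placeholders rather than my exact code):

population_size = 10  # placeholder value
agent_models = [QNet.linear_QNet(input_size=11, hidden_sizes=[256], output_size=3)
                for _ in range(population_size)]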
QTrainer (the model.fit at the bottom of this is where I get the warning):
import numpy as np

class QTrainer():
    def __init__(self, model, gamma=0.9):
        self.model = model
        self.gamma = gamma

    def train_step(self, state, action, reward, next_state, done):
        '''
        The training function. Handles both long
        and short term memory training.
        '''
        # Should be in the shape (n, x)
        state = np.array(state, dtype=float)
        next_state = np.array(next_state, dtype=float)
        action = np.array(action, dtype=int)
        reward = np.array(reward, dtype=float)

        if len(state.shape) == 1:
            # If it is in the shape (x,), convert it to (1, x)
            state = np.expand_dims(state, 0)
            next_state = np.expand_dims(next_state, 0)
            action = np.expand_dims(action, 0)
            reward = np.expand_dims(reward, 0)
            done = (done, )

        # Predicted Q values with current state
        pred = self.model(state)
        target = np.array(pred, copy=True)

        for i in range(len(done)):
            # Reshape reward, if needed
            if len(reward[i].shape) == 1:
                _reward = np.expand_dims(reward[i], 0)
            else:
                _reward = reward[i]
            # Check to see if the update formula should be used
            if done[i]:
                Q_new = _reward
            else:
                # Reshape next_state, if needed
                if len(next_state[i].shape) == 1:
                    _next_state = np.expand_dims(next_state[i], 0)
                else:
                    _next_state = next_state[i]
                # Formula: reward + gamma * max Q(next_state)
                Q_new = _reward + self.gamma * np.max(self.model(_next_state))
            # Reshape action, if needed
            if len(action[i].shape) == 1:
                _action = np.expand_dims(action[i], 0)
            else:
                _action = action[i]
            # Set the target for the action that was taken
            target[i][np.argmax(_action).item()] = Q_new

        # Train the model
        self.model.fit(x=state,
                       y=target,
                       epochs=1,
                       batch_size=len(done),
                       verbose=0)
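The Agent class itself isn't shown here, but train_short_memory is essentially just a thin wrapper that hands the latest transition to the trainer (paraphrased, so the exact names are approximate):

class Agent():
    def __init__(self, model):
        self.model = model
        self.trainer = QTrainer(self.model, gamma=0.9)

    def train_short_memory(self, state, action, reward, next_state, done):
        # A single training step on the transition that just happened
        self.trainer.train_step(state, action, reward, next_state, done)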
The code where I run my population of agents:
def _run_pop_of_agents(self):
    ''' Run a population of agents with an entire generation being processed at once. '''
    agents_left = len(self.agents)  # loop ends when all agents die
    dones = [False for i in range(agents_left)]
    self.all_episodes += 1
    self.gen_episodes += 1

    while agents_left:
        # Check for escape
        for event in pyg_get():
            # Check for exiting out of window
            if event.type == pyg_QUIT:
                self.quit = True
            elif event.type == pyg_KEYDOWN:
                if event.key == pyg_K_ESCAPE:
                    self.quit = True
        if self.quit: break

        # Process each agent's step
        for agent_num, agent in enumerate(self.agents):
            # Check to see if this agent already died
            if dones[agent_num]: continue
            # Get old state
            state_old = agent.get_state(self.game)
            # Get move
            final_move = agent.get_action(state_old)
            # Perform move and get new state
            reward, done, score = self.game.play_step(final_move)
            state_new = agent.get_state(self.game)
            # Train short memory
            agent.train_short_memory(state_old, final_move, reward, state_new, done)
            # Remember
            agent.remember(state_old, final_move, reward, state_new, done)
For comparison, here's the code where I run individual agents. It works just fine.
def _run_episode(self, agent, single_agent=False):
    ''' Run an episode of the game. '''
    run = True
    while run:
        # Check for escape
        for event in pyg_get():
            # Check for exiting out of window
            if event.type == pyg_QUIT:
                self.quit = True
            elif event.type == pyg_KEYDOWN:
                if event.key == pyg_K_ESCAPE:
                    self.quit = True
        if self.quit: break

        # Get old state
        state_old = agent.get_state(self.game)
        # Get move
        final_move = agent.get_action(state_old)
        # Perform move and get new state
        reward, done, score = self.game.play_step(final_move)
        state_new = agent.get_state(self.game)
        # Train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        # Remember
        agent.remember(state_old, final_move, reward, state_new, done)
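And for completeness, this is roughly how the two paths get invoked (trimmed down, and the population_mode flag is a stand-in for my actual setting):

if self.population_mode:  # stand-in flag name
    self._run_pop_of_agents()  # this path produces the retracing warning
else:
    for agent in self.agents:
        self._run_episode(agent, single_agent=True)  # this path does not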
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow