model.fit triggers tf.function retracing

I'm not entirely sure what code will be useful for this, so bear with me if there's a lot of it.

I've been working on training a population of agents within a single generation as part of my attempt to develop a deep reinforcement learning genetic algorithm. As part of that, I want to display each of the agents taking an action.

So I have my agents in a list. For each agent, I run that agent's own model to get the action it should take, the agent performs the action, and then it trains its short-term memory to try to improve itself just a bit.

When training their short term memory, I obviously use model.fit(). But when I do, I get the following message:

WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_train_function.<locals>.train_function at 0x000001F1ED719A60> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
1/1 [==============================] - 0s 231ms/step - loss: 1.5837e-05 - root_mean_squared_error: 0.0040

Print-statement debugging tells me that the model.fit call is definitely where the problem is, but I don't understand what the issue would be. When I process a single agent (running one agent until it dies) instead of running all the agents at once, model.fit works just fine.
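
If it helps, here's a stripped-down sketch of the calling pattern I believe matters: several independent Keras models, each getting a one-step fit() call inside the same loop. This isn't my real code (the names and sizes are placeholders), just the shape of what the population loop below ends up doing:

import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Input

# Placeholder models standing in for the agents' individual networks
models = [Sequential([Input(shape=(11,)), Dense(3)]) for _ in range(5)]
for m in models:
    m.compile(optimizer="adam", loss="mse")

x = np.random.rand(1, 11)
y = np.random.rand(1, 3)

# Every game step, each agent's own model gets its own one-step fit() call.
# My guess is that each of these calls traces a separate train_function,
# and that this is what the retracing warning is counting, but I'm not sure.
for m in models:
    m.fit(x, y, epochs=1, batch_size=1, verbose=0)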

Now, as promised, here's the code. I don't know what's useful, so I'm giving a lot of it:

The model:

from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import RootMeanSquaredError


class QNet():
    @staticmethod
    def linear_QNet(model_path="", input_size=11, hidden_sizes=[256], output_size=3, learning_rate=0.001):
        '''
        Build the model, given an input size, list of hidden layer sizes,
        output size, and learning rate.

        Either load a trained model or create a new one.
        '''
        # If you want to load a trained model...
        if model_path:
            model = load_model(model_path)

        # If you want to build a new model
        else:
            model = Sequential()
            model.add(Input(shape=(input_size,)))
            for layer_size in hidden_sizes:
                model.add(Dense(units=layer_size, activation="relu"))
            model.add(Dense(units=output_size))
            model.compile(optimizer=Adam(learning_rate=learning_rate),
                          loss="mean_squared_error",
                          metrics=[RootMeanSquaredError()])
        return model
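
Each agent gets its own, freshly compiled network from this function, roughly like so (the exact arguments aren't important):

# Every agent builds its own separate model, so there are as many
# compiled models as there are agents in the population.
model = QNet.linear_QNet(input_size=11, hidden_sizes=[256], output_size=3)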

QTrainer (the model.fit at the bottom of this is where I get the warning):

import numpy as np

class QTrainer():
    def __init__(self, model, gamma=0.9):
        self.model = model
        self.gamma = gamma


    def train_step(self, state, action, reward, next_state, done):
        '''
        The training function. Handles both long
        and short term memory training.
        '''
        # Inputs should end up with shape (n, x)
        state = np.array(state, dtype=float)
        next_state = np.array(next_state, dtype=float)
        action = np.array(action, dtype=int)
        reward = np.array(reward, dtype=float)

        if len(state.shape) == 1:
            # If the inputs came in with shape (x,),
            # add a batch dimension so they become (1, x)
            state = np.expand_dims(state, 0)
            next_state = np.expand_dims(next_state, 0)
            action = np.expand_dims(action, 0)
            reward = np.expand_dims(reward, 0)
            done = (done, )

        # Predicted Q values with current state
        pred = self.model(state)

        target = np.array(pred, copy=True)  
        for i in range(len(done)):
            # Reshape reward, if needed
            if len(reward[i].shape) == 1:
                _reward = np.expand_dims(reward[i], 0)
            else:
                _reward = reward[i]

            # Check to see if the formula should be used
            if done[i]:
                Q_new = _reward
            else:
                # Reshape next_state, if needed
                if len(next_state[i].shape) == 1:
                    _next_state = np.expand_dims(next_state[i], 0)
                else:
                    _next_state = next_state[i]

                # Bellman update: reward plus discounted max Q of the next state
                Q_new = _reward + self.gamma * np.max(self.model(_next_state))

            # Reshape action, if needed
            if len(action[i].shape) == 1:
                _action = np.expand_dims(action[i], 0)
            else:
                _action = action[i]

            # Overwrite the target Q-value for the action that was taken
            target[i][np.argmax(_action).item()] = Q_new

        # Train the model
        self.model.fit(x=state,
                       y=target,
                       epochs=1,
                       batch_size=len(done),
                       verbose=0)
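
The train_short_memory call in the loops below is just a thin wrapper around this train_step. Simplified, the relevant part of my Agent class looks roughly like this:

class Agent():
    def __init__(self):
        # Each agent owns its own model and its own trainer around that model
        self.model = QNet.linear_QNet(input_size=11, hidden_sizes=[256], output_size=3)
        self.trainer = QTrainer(self.model, gamma=0.9)

    def train_short_memory(self, state, action, reward, next_state, done):
        # One-step training on the transition that just happened;
        # this is where the model.fit() above ends up being called
        self.trainer.train_step(state, action, reward, next_state, done)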

The code where I run my population of agents:

def _run_pop_of_agents(self):
    ''' Run a population of agents with an entire generation being processed at once. '''
    agents_left = len(self.agents) # loop ends when all agents die
    dones = [False for i in range(agents_left)]
    self.all_episodes += 1
    self.gen_episodes += 1
    while agents_left:
        # Check for escape
        for event in pyg_get():
            # Check for exiting out of window
            if event.type == pyg_QUIT:
                self.quit = True
            elif event.type == pyg_KEYDOWN:
                if event.key == pyg_K_ESCAPE:
                    self.quit = True
        if self.quit: break

        # Process each agent's step
        for agent_num, agent in enumerate(self.agents):
            # Check to see if this agent already died
            if dones[agent_num]: continue

            # Get old state
            state_old = agent.get_state(self.game)

            # Get move
            final_move = agent.get_action(state_old)

            # Perform move and get new state
            reward, done, score = self.game.play_step(final_move)
            state_new = agent.get_state(self.game)

            # Train short memory
            agent.train_short_memory(state_old, final_move, reward, state_new, done)

            # Remember
            agent.remember(state_old, final_move, reward, state_new, done)

For comparison, here's the code where I run individual agents one at a time, so the same model gets fit repeatedly instead of the fit calls interleaving across every agent's separate model. It works just fine.

def _run_episode(self, agent, single_agent=False):
    ''' Run an episode of the game. '''
    run = True
    while run:
        # Check for escape
        for event in pyg_get():
            # Check for exiting out of window
            if event.type == pyg_QUIT:
                self.quit = True
            elif event.type == pyg_KEYDOWN:
                if event.key == pyg_K_ESCAPE:
                    self.quit = True
        if self.quit: break

        # Get old state
        state_old = agent.get_state(self.game)

        # Get move
        final_move = agent.get_action(state_old)

        # Perform move and get new state
        reward, done, score = self.game.play_step(final_move)
        state_new = agent.get_state(self.game)

        # Train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # Remember
        agent.remember(state_old, final_move, reward, state_new, done)

