In tf-agents, the driver is generating infinite timesteps and the observer is not working

I am using tf-agents to implement a contextual bandit algorithm. I am using BatchedPyEnvironment to create batched time steps from copies of a single environment.

However, this environment seems to make the driver run for an infinite number of steps. When I replace the batched environment with some other environment (in the env argument of the driver), driver.run() works as intended.
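For debugging, stepping the batched environment by hand (without the driver) shows what it actually emits; my understanding from reading the DynamicStepDriver source, which may be wrong, is that only non-boundary time steps count toward num_steps:

# Sanity check without the driver (tfenv is defined in the code below):
# inspect the step types and shapes the batched environment emits.
time_step = tfenv.reset()
print('after reset:', time_step.step_type.numpy(), time_step.observation.shape)
action = tf.zeros((4,), dtype=tf.int32)  # dummy batch of valid actions
time_step = tfenv.step(action)
print('after step:', time_step.step_type.numpy(), time_step.reward.numpy())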

Here is my code; can someone please tell me what the issue is?

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.bandits.agents import linear_thompson_sampling_agent
# BanditPyEnvironment as in the tf-agents bandits tutorial / library.
from tf_agents.bandits.environments.bandit_py_environment import BanditPyEnvironment
from tf_agents.bandits.environments import stationary_stochastic_py_environment as sspe
from tf_agents.bandits.metrics import tf_metrics
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import batched_py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.specs.tensor_spec import BoundedTensorSpec
from tf_agents.trajectories import time_step as ts


class SampleEnvironment(BanditPyEnvironment):

    def __init__(self):
        observation_spec = BoundedTensorSpec(
            (2,), np.int32, minimum=[1, 1], maximum=[5, 2], name='observation')
        action_spec = BoundedTensorSpec(
            shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
        super(SampleEnvironment, self).__init__(observation_spec, action_spec)

    def _observe(self):
        # Unbatched (2,) observation; BatchedPyEnvironment stacks these.
        self.observation = tf.cast(
            np.array([np.random.choice([1, 2, 3, 4, 5]),
                      np.random.choice([1, 2])]), 'int32')
        print('\n in observation', self.observation, '\n')
        return self.observation

    def _apply_action(self, action):
        # High cost when the action matches the first context feature.
        if self.observation[0] == action:
            cost = action + 100
        else:
            cost = self.observation[0]
        print('\n in apply_action - observation is ', self.observation,
              ', cost is ', cost, '\n')
        return tf.cast(cost, np.float32)
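For comparison, if I remember the tf-agents bandits tutorial correctly, its custom environments declare numpy array specs and return plain numpy arrays from _observe, whereas my environment above uses BoundedTensorSpec and tf tensors. A sketch of the array-spec style for the same bounds:

# Tutorial-style specs (sketch): array specs plus numpy observations.
from tf_agents.specs import array_spec

observation_spec = array_spec.BoundedArraySpec(
    shape=(2,), dtype=np.int32, minimum=[1, 1], maximum=[5, 2],
    name='observation')
action_spec = array_spec.BoundedArraySpec(
    shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')

# _observe would then return an unbatched numpy array, e.g.
# np.array([np.random.choice([1, 2, 3, 4, 5]), np.random.choice([1, 2])],
#          dtype=np.int32)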


batch_size = 4
py_envs = [SampleEnvironment() for _ in range(batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(
    envs=py_envs, multithreading=False)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)

observation_spec = tfenv.observation_spec()
time_step_spec = ts.time_step_spec(observation_spec)
action_spec = tfenv.action_spec()
agent1 = linear_thompson_sampling_agent.LinearThompsonSamplingAgent(
    time_step_spec=time_step_spec, action_spec=action_spec)

# compute_optimal_reward is defined elsewhere in my code.
regret_metric = tf_metrics.RegretMetric(compute_optimal_reward)

steps_per_loop = 1

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent1.policy.trajectory_spec,
    batch_size=batch_size,
    max_length=steps_per_loop)

observers = [replay_buffer.add_batch, regret_metric]

driver = dynamic_step_driver.DynamicStepDriver(
    env=tfenv,  # replacing tfenv with environment2 (below) works
    policy=agent1.collect_policy,
    num_steps=batch_size * steps_per_loop,
    observers=observers)
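My expectation, which is an assumption based on my reading of the driver API, is that with num_steps = batch_size * steps_per_loop = 4 and a batch of 4, a single run() should take exactly one batched environment step and return:

# Expected: a single batched step per run(), since num_steps == batch_size.
final_time_step, final_policy_state = driver.run()
print(final_time_step.step_type)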

When I replace tfenv above with another environment like the one below, driver.run() works as expected; with tfenv, the driver seems to run forever. The environment below has 6 actions too; I removed the remaining per-action reward functions for brevity.

batch_size = 4  # @param

def context_sampling_fn(batch_size):

    def _context_sampling_fn():
        batch = []
        for _ in range(batch_size):
            each = tf.cast(np.array([np.random.choice([10, 20, 30, 40, 50]),
                                     np.random.choice([1, 2])]), 'int32')
            batch.append(each)
        observation = np.array(batch)
        print('in observe', observation, observation.shape)
        return observation

    return _context_sampling_fn

def reward_fun0(x):
    print('in reward function of action 0', x, '\n')
    return 0

def reward_fun1(x):
    print('in reward function of action 1', x, '\n')
    return 1


environment2 = tf_py_environment.TFPyEnvironment(
    sspe.StationaryStochasticPyEnvironment(
        context_sampling_fn(batch_size),
        [reward_fun0, reward_fun1],
        batch_size=batch_size))
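One structural difference between the two setups: the working environment's sampling function returns an already-batched (4, 2) context array, while SampleEnvironment returns unbatched (2,) observations that BatchedPyEnvironment stacks itself. This can be checked directly:

# The working environment samples a pre-batched (4, 2) context array:
sample = context_sampling_fn(4)()
print(sample.shape, sample.dtype)  # expected: (4, 2) int32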

Below is the rest of the code that I am using:

regret_values = []
for _ in range(5):

    driver.run()

    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size, num_steps=steps_per_loop)
    iterator = iter(dataset)

    loss = None
    for j in range(batch_size):
        trajectories, _ = next(iterator)
        # Also tried experience=replay_buffer.gather_all() here.
        loss = agent1.train(experience=trajectories)

    replay_buffer.clear()
    regret_values.append(regret_metric.result())

plt.plot(regret_values)
plt.ylabel('Average Regret')
plt.xlabel('Number of Iterations')
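For completeness, the alternative training call noted in the comment above pulls the whole buffer in one call; gather_all is deprecated in recent tf-agents releases in favor of as_dataset(..., single_deterministic_pass=True), so this is just what I tried:

# Alternative I also tried: train on the whole buffer at once.
experience = replay_buffer.gather_all()  # deprecated in newer tf-agents
loss = agent1.train(experience=experience)
replay_buffer.clear()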


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
