Error while defining observation space in gym custom environment

I am working on a reinforcement learning algorithm. I am very new to this and still trying to get a handle on things.

Player1Env looks upon a 7x6 Connect4 playing grid. I am initializing the class as follows:

def __init__(self):
    super(Player1Env, self).__init__()
    self.action_space = spaces.Discrete(7)
    self.observation_space = spaces.Box(low=-1, high=1, shape=(7, 6), dtype=np.float32)

Checking whether the class is instantiated correctly with

env = Player1Env()
check_env(env)

returns the error

AssertionError: The observation returned by the `reset()` method does not match the given observation space

Printing the observation returned by the reset function and its shape gives:

[[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]
(7, 6)

low and high are defined as -1 and 1 respectively, since the grid represents the current board state: 1 marks the stones dropped in by player 1 and -1 the stones dropped in by player 2. This part of the code has been tested extensively, but even changing the bounds to -np.inf and np.inf does not change the error message.

The reset function itself:

def reset(self):
    self.board = np.zeros((7, 6))
    self.player = 1

    self.reward = 0
    self.done = False

    observation = self.board

    return observation

The step function pits the RL algorithm against a preprogrammed agent, but the error should be coming from the reset function anyway.

Could you help me out with where the error is coming from?

Edit: There was a UserError with the numpy API compiling against the wrong version that didn't seem to impact usability (everything worked in the premade gym environments). I managed to fix that error, but the observation space definition problem still persists.



Solution 1:[1]

Solution to your problem:

If you define self.board in reset() as below, your problem is solved:

self.board = np.zeros((7, 6), dtype=np.float32)

More details and examples are presented at the end of this answer.
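As a quick illustration of why the dtype argument matters for the 7x6 board, the following sketch uses only NumPy. It is not part of the original code; the names board_default and board_float32 are made up for this example.

```python
import numpy as np

# np.zeros without a dtype argument produces float64
board_default = np.zeros((7, 6))
# passing dtype explicitly produces float32, which matches the Box in the question
board_float32 = np.zeros((7, 6), dtype=np.float32)

print(board_default.dtype, board_float32.dtype)   # float64 float32

# float64 cannot be safely cast down to float32, while float32 trivially can
print(np.can_cast(board_default.dtype, np.float32))   # False
print(np.can_cast(board_float32.dtype, np.float32))   # True
```

Both arrays have the same shape and the same values, so printing them looks identical; only the dtype differs.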


General answer: minimal example of custom env with Box observation space in gym

The dtype of the Box and the dtype of the observation must be the same. Here both are float32.

import numpy as np
from gym import Env
from gym.spaces import Box
# assumption: check_env is the Stable-Baselines3 checker, as the question's error message suggests
from stable_baselines3.common.env_checker import check_env


class customEnv(Env):
    def __init__(self):
        # Box uses dtype=np.float32 when no dtype is given
        self.action_space = Box(low=np.array([0.0]), high=np.array([1.0]))
        self.observation_space = Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 1.0]))
        self.state = np.array([0.5, 0.5], dtype=np.float32)

    def step(self, action):
        # the variables below must be defined to prevent errors in check_env
        reward = 1
        done = False
        info = {}
        return self.state, reward, done, info

    def reset(self):
        self.state = np.array([0.5, 0.5], dtype=np.float32)   # np.float32 is essential
        return self.state

    def render(self):
        pass

env = customEnv()
check_env(env, warn=True)

Example about numpy and gym.spaces.Box dtype:

When you define a custom env in gym, check_env checks several things. In this case, the check that fails is observation_space.contains(observation).

Here, self.board (the variable named observation in the reset() method) is not contained in observation_space, because observation.dtype is float64 while observation_space.dtype is float32.

NumPy creates floating-point arrays as float64 by default, whereas the default dtype of a Box object is float32. Versions: numpy 1.21.5, gym 0.21.0
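To make the failing check concrete, here is a rough sketch of the kind of test a Box-style space performs. The function box_like_contains is a hypothetical stand-in written for this answer, not gym's actual code; it assumes the membership test combines a safe-cast dtype check with shape and bounds checks.

```python
import numpy as np

def box_like_contains(x, low, high, shape, dtype=np.float32):
    """Hypothetical stand-in for a Box-style membership test (not gym's code)."""
    return bool(
        np.can_cast(x.dtype, dtype)   # dtype must be safely castable to the space's dtype
        and x.shape == shape          # shape must match exactly
        and np.all(x >= low)          # every entry at or above the lower bound
        and np.all(x <= high)         # every entry at or below the upper bound
    )

board64 = np.zeros((7, 6))                     # NumPy default: float64
board32 = np.zeros((7, 6), dtype=np.float32)   # explicit float32

result64 = box_like_contains(board64, -1.0, 1.0, (7, 6))
result32 = box_like_contains(board32, -1.0, 1.0, (7, 6))
# widening the bounds does not help, because the dtype check fails first
result_inf = box_like_contains(board64, -np.inf, np.inf, (7, 6))

print(result64, result32, result_inf)   # False True False
```

Under this assumption, it also becomes clear why changing low and high to -np.inf and np.inf leaves the error unchanged: the bounds are never the part that fails.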

import numpy as np
import gym 
from gym.spaces import Box


# example 1; by this definition you get error
In [1]: observation_space = Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 1.0]))
In [2]: observation = np.array([0.5, 0.5])
In [3]: print(observation.dtype)
In [4]: observation_space.contains(observation)     # does observation_space contain observation?

out[3]: float64
out[4]: False

# example 2; this definition works fine; no error
In [10]: observation_space_2 = Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 1.0]))
In [11]: observation_2 = np.array([0.5, 0.5], dtype=np.float32)
In [12]: print(observation_2.dtype)
In [13]: observation_space_2.contains(observation_2)     # does observation_space_2 contain observation_2?

out[12]: float32
out[13]: True
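The same dtype requirement applies to the observation returned by step(), not only to the one from reset(). One possible way to avoid repeating the dtype everywhere is to cast in a single place before returning. The helper name to_observation below is hypothetical and only illustrates the idea on a 7x6 board.

```python
import numpy as np

def to_observation(board):
    """Hypothetical helper: return the board as a float32 array."""
    return np.asarray(board, dtype=np.float32)

board = np.zeros((7, 6))   # float64 by default
board[3, 0] = 1            # a stone dropped by player 1
board[4, 0] = -1           # a stone dropped by player 2

obs = to_observation(board)
print(obs.dtype, obs.shape)   # float32 (7, 6)
```

Both reset() and step() could then return to_observation(self.board). Creating the board as float32 from the start, as shown at the top of this answer, works equally well.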

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
