'DQN model for Atari game never learns
I'm trying to implement an Pong game with DQN model by torch. However I got two problems during the execution. Firstly, I found that the game never get done. Secondly, I found the loss function does not have any change in the trainning. This is my code below: I defined a CNN network with the input of the size (batch=32, channels=4, height=84, weight=84). By this step there's nothing wrong happened:
class CNN(nn.Module):
def __init__(self, s_channels, a_space):
super(CNN, self).__init__()
self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
self.conv1 = nn.Conv2d(s_channels,out_channels=32,kernel_size=8,stride=4)
self.conv2 = nn.Conv2d(32,64,4,2)
self.conv3 = nn.Conv2d(64,64,3,1)
self.fc1 = nn.Linear(64*4*4,1024)
self.fc2 = nn.Linear(1024,512)
self.fc3 = nn.Linear(512,a_space)
def forward(self,input):
output = self.pool(F.relu(self.conv1(input)))
output = self.pool(F.relu(self.conv2(output)))
output = self.pool(F.relu(self.conv3(output)))
output = output.view(-1,64*4*4)
output = F.relu(self.fc1(output))
output = F.relu(self.fc2(output))
output = F.relu(self.fc3(output))
return output
For the agent class, I defined a back propagation function to replay the weight in CNN and the data pre-processing function:
# Agent
class Agent():
def __init__(self, s_space, a_space) -> None:
# define parameters
self.epsilon = 1.0
self.min_epsilon = 0.01
self.dr = 0.995
self.lr = 0.001
self.gamma = 0.9
# define models
self.evl_net = CNN(s_space, a_space)
self.tgt_net = CNN(s_space, a_space)
self.cert = nn.SmoothL1Loss()
self.optimal = th.optim.Adam(self.evl_net.parameters(),lr=self.lr)
# define memory store
self.memory = deque(maxlen=2000)
# self.img_stack = deque(maxlen=4)
# pre-processing frame images: transform the imaages into tensors
# def bsl_image_pre_process(self,env):
# env = aw.AtariWrapper(env,noop_max=30,frame_skip=4,screen_size=84,terminal_on_life_loss=True,clip_reward = True)
# return env
def gym_image_pre_process(self,env):
#Atari preprocessing
env = gym.wrappers.AtariPreprocessing(env, noop_max=30, frame_skip=4, screen_size=84, terminal_on_life_loss=False, grayscale_obs=True, grayscale_newaxis=False, scale_obs=False)
#create frame stack
env = gym.wrappers.FrameStack(env, 4)
channels = env.observation_space.shape[0]
return env,channels
# env = aw.AtariWrapper(env,noop_max=30,frame_skip=4,screen_size=84,terminal_on_life_loss=True,clip_reward = True)
# return env
def data_pre_process(self,batch_size):
s_v = []
a_v = []
next_s_v = []
r_v = []
dones = []
materials = random.sample(self.memory,batch_size)
for t in materials:
s_v.append(t[0])
a_v.append(t[1])
next_s_v.append(t[2])
r_v.append(t[3])
dones.append(t[4])
# print(th.FloatTensor(r_v))
# print(th.FloatTensor(r_v).size())
# print(s_v)
s_v = th.Tensor(s_v) # size: [32,3,210,160]
a_v = th.LongTensor(a_v).unsqueeze(1) # size: [32,1]
next_s_v = th.Tensor(next_s_v) # size: [32,3,210,160]
r_v = th.FloatTensor(r_v) # size: [32]
return s_v, a_v, next_s_v, r_v, dones
# remember the transformed images
def record(self,tpl):
self.memory.append(tpl)
# select actions according to the states (input images with 4 channels)
def select(self,state,a_space):
actions = self.evl_net(state).data.tolist()
if(random.random() <= self.epsilon):
action = random.randint(0,a_space-1)
else:
action = actions.index(max(actions))
return action
# DQN trainning progression
def train(self,state,batch_size):
s_v,a_v,next_s_v,r_v,dones = self.data_pre_process(batch_size)
self.tgt_net.load_state_dict(self.evl_net.state_dict())
evl_Q_value = self.evl_net(s_v).gather(0,a_v) # size: [32,6].gather() -> [32,1]
tgt = self.tgt_net(next_s_v).max(1)[0].detach() # size [32,1]
tgt_Q_value = (r_v + self.gamma * tgt)
for index in range(len(dones)):
if(dones[index]==True):
tgt[index][0] = -1
# print(tgt_Q_value)
tgt_Q_value = tgt_Q_value.reshape(batch_size,1) # size: [32, 1] cannot be back propagated
# print(tgt_Q_value)
self.optimal.zero_grad()
loss = self.cert(evl_Q_value, tgt_Q_value)
print(loss)
loss.backward()
for pr in self.evl_net.parameters():
pr.grad.data.clamp_(-1, 1)
self.optimal.step()
if(self.epsilon > self.min_epsilon):
self.epsilon *= self.dr
At the training stage, I found the first question. the condition of done in each episode is always false. With gym.wrappers I've pre-processed the image tensor into 48484 and the environment with only one life. But it still appears:
# main test
_display = Display(visible=0, size=(900,1400))
_display.start()
# set episode step and batch_size
episodes = 5000
batch_size = 32
env = gym.make("PongNoFrameskip-v4")
env = gym.wrappers.AtariPreprocessing(env, noop_max=30, frame_skip=4, screen_size=84, terminal_on_life_loss=False, grayscale_obs=True, grayscale_newaxis=False, scale_obs=False)
# create frame stack for the input image data (size: (4,84,84))
env = gym.wrappers.FrameStack(env, 4)
channels = env.observation_space.shape[0]
a_space = env.action_space.n
agent = Agent(channels, a_space)
# env.render()
# testing:
for e in range(episodes):
# step 1: reset the agent at the beginning
s = np.array(env.reset())
for run in range(100):
score = 0
# display.clear_output(wait=True)
# display.display(Image.fromarray(env.render(mode='rgb_array')))
# env.render("rgb_array")
img = plt.imshow(env.render('rgb_array'))
# step 2: create state space tensor
# step 3: iterate actions
a = agent.select(th.Tensor(s).unsqueeze(0),a_space)
next_s, reward, done, _ = env.step(a)
if(done==True):
next_s = None
next_s = np.array(next_s) # done is never true. Why?
# step 4: record the data into buffer
dataset = (s,a,next_s,reward,done)
agent.record(dataset)
# step 5: update state steps
s = next_s
score += reward
if(done==True or run == 99):
print("episodes:",e,"score:",score,"epsilon: {:.2}".format(agent.epsilon))
break
# step 6: training and update CNN
if(len(agent.memory) > batch_size):
agent.train(channels,batch_size)
As I tried to find this problem, I detected that the loss value never even roughly decreases(at most fluctuate around 1.2). I rechecked the input and output tensor but found nothing else. I hope to get some help for how to fix these two problems. Many thanks!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
