How to select the Q value in a DQN when Q is a multi-dimensional array

I'm implementing a DQN for stock trading (for educational purposes only).

Below are my data and its shape. It is a single state taken from a time series, which I pass to a neural network. The first column is the closing price of a stock and the second column is the volume (already normalized):

array([[[-0.39283217,  3.96508668],     
        [-0.39415516,  0.04931261],
        [-0.38271683, -0.34029827],
        [-0.39283217, -0.42384451],
        [-0.4332384 , -0.11795849],
        [-0.41201548, -0.47441503],
        [-0.41739012, -0.51788375],
        [-0.42210326, -0.60101319],
        [-0.43660099, -0.596672  ],
        [-0.43660099, -0.64244935]]])

(1, 10, 2)

Now I pass this data to a neural network. It is essentially the policy network, simplified here for the question (the loss is Q minus the target Q value):

from tensorflow import keras  # or `import keras`, depending on the installation

model = keras.Sequential([
    keras.layers.Input(shape=(10, 2)),
    keras.layers.Dense(10, activation='relu'),
    keras.layers.Dense(3, activation='linear')
])

# count_the_loss() is my own loss function (Q minus target Q), not shown here
model.compile(loss=count_the_loss(),
              optimizer='adam',
              metrics=['mse'])

Calling the predict function on the state above returns this:

array([[[-0.79352564, -0.22876596,  2.309589  ],
        [-0.10996505,  0.01430818,  0.22286436],
        [-0.17374574,  0.03645202,  0.10073717],
        [-0.19824156,  0.07159233,  0.08594725],
        [-0.12234195,  0.03734204,  0.19439939],
        [-0.21589771,  0.088783  ,  0.08315123],
        [-0.22866695,  0.10703149,  0.07550874],
        [-0.25188142,  0.1436682 ,  0.05827002],
        [-0.25386256,  0.13714936,  0.06612003],
        [-0.26608405,  0.1581351 ,  0.05540368]]], dtype=float32)

I expected to get q(s,a1), q(s,a2), q(s,a3) (where a1, a2 and a3 stand for the actions short, flat and long, respectively), and then pick the q for the action sampled from the experience replay.

Instead, I get a 1x10x3 array.
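For context on where the 1x10x3 shape comes from: a Dense layer multiplies along the last axis only, so with an input of shape (10, 2) each of the 10 time steps gets its own 3 outputs. The NumPy sketch below mimics that arithmetic with random weights (the `dense` helper and all weight names are hypothetical, not part of the model above) and shows that flattening the window first yields one row of 3 values per state. In Keras, a `keras.layers.Flatten()` layer placed before the first Dense layer performs the same reshape; whether flattening is the right design for this trading model is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(x, w, b):
    # Same arithmetic as a linear Dense layer: matmul over the last axis, plus bias.
    return x @ w + b

state = rng.normal(size=(1, 10, 2))  # (batch, time steps, features)

# Case 1: no flattening. The weights only see the last axis (2 features),
# so every time step is mapped to its own 3 outputs.
w1, b1 = rng.normal(size=(2, 10)), np.zeros(10)
w2, b2 = rng.normal(size=(10, 3)), np.zeros(3)
hidden = np.maximum(dense(state, w1, b1), 0.0)  # ReLU
q_per_step = dense(hidden, w2, b2)
print(q_per_step.shape)  # (1, 10, 3)

# Case 2: flatten the 10x2 window into 20 features first,
# so the whole window is mapped to a single set of 3 outputs.
flat = state.reshape(state.shape[0], -1)  # (1, 20)
w1f, b1f = rng.normal(size=(20, 10)), np.zeros(10)
hidden_f = np.maximum(dense(flat, w1f, b1f), 0.0)  # ReLU
q_per_state = dense(hidden_f, w2, b2)
print(q_per_state.shape)  # (1, 3)
```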

My questions are:

  1. How am I supposed to get the q for the sampled action?

  2. Once that is done, I need to find the target Q, which is a similar process. Suppose the result above is what I get by passing next_state to a target network. I have to find the max q over the actions. How can I find the max q in a 1x10x3 array?
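To make the two steps in the questions concrete, here is a minimal NumPy sketch of what they would look like if the network returned one row of 3 Q values per state, that is, shape (batch, 3) rather than (batch, 10, 3). All numbers, the batch size, `gamma`, and the `rewards` and `dones` arrays are made up for illustration; the action indices 0, 1, 2 are assumed to mean short, flat and long.

```python
import numpy as np

# Hypothetical Q values from the policy network, shape (batch, n_actions).
q_values = np.array([[0.1, 0.5, -0.2],
                     [1.0, 0.3, 0.7],
                     [-0.4, -0.1, -0.3],
                     [0.2, 0.2, 0.9]])

# Hypothetical Q values from the target network for next_state, same shape.
next_q_values = np.array([[0.0, 0.4, 0.1],
                          [0.6, 0.2, 0.5],
                          [-0.2, -0.5, -0.1],
                          [0.3, 0.8, 0.1]])

actions = np.array([1, 0, 2, 2])            # sampled from replay: 0=short, 1=flat, 2=long
rewards = np.array([1.0, 0.0, -1.0, 0.5])   # made-up rewards
dones = np.array([0.0, 0.0, 1.0, 0.0])      # 1.0 marks a terminal transition
gamma = 0.99                                # assumed discount factor

# Step 1: pick q(s, a) for the action actually taken in each transition.
q_taken = q_values[np.arange(len(actions)), actions]

# Step 2: max over the action axis of the target network output.
max_next_q = next_q_values.max(axis=1)

# Standard one-step DQN target; terminal transitions get no bootstrap term.
targets = rewards + gamma * (1.0 - dones) * max_next_q

print(q_taken)     # one value per transition
print(max_next_q)  # one value per transition
print(targets)
```

If the model keeps returning a (batch, 10, 3) array, the same indexing would first need that array reduced to (batch, 3), for example by changing the architecture so it emits one set of Q values per state.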



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow