How to select Q value in DQN where Q is a multi-dimensional array
I'm implementing a DQN to trade in the stock market (for educational purposes only).
Below is one state from my time series data, together with its shape; I pass this state to a neural network. The first column is the closing price of a stock and the second column is the volume (already normalized):
```
array([[[-0.39283217,  3.96508668],
        [-0.39415516,  0.04931261],
        [-0.38271683, -0.34029827],
        [-0.39283217, -0.42384451],
        [-0.4332384 , -0.11795849],
        [-0.41201548, -0.47441503],
        [-0.41739012, -0.51788375],
        [-0.42210326, -0.60101319],
        [-0.43660099, -0.596672  ],
        [-0.43660099, -0.64244935]]])
(1, 10, 2)
```
I then pass this state to a neural network. In my real code it's a policy network, but to simplify the question I've written it like this (the loss is Q minus the target Q value):
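```python
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=(10, 2)),           # one state: 10 time steps x 2 features
    keras.layers.Dense(10, activation='relu'),
    keras.layers.Dense(3, activation='linear'),  # 3 outputs: short, flat, long
])
model.compile(loss=count_the_loss(),  # custom loss (Q - target Q), not shown here
              optimizer='adam',
              metrics=['mse'])
```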
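The post describes the loss only as "Q - target Q value". In the standard DQN formulation that difference (the temporal-difference error) is squared, or passed through a Huber loss, and averaged over the minibatch. The expression below is that standard form and may differ from the author's `count_the_loss()`; here $\theta$ are the online network's weights, $\theta^-$ the target network's weights, $N$ the minibatch size, and $d_i$ is 1 for a terminal transition and 0 otherwise.

$$
y_i = r_i + \gamma\,(1 - d_i)\,\max_{a'} Q(s'_i, a'; \theta^-), \qquad
L(\theta) = \frac{1}{N}\sum_{i=1}^{N}\bigl(Q(s_i, a_i; \theta) - y_i\bigr)^2
$$

Both terms need exactly one scalar per transition: $Q(s_i, a_i; \theta)$ is a single entry of the network's output for state $s_i$, and the max runs over the three actions only.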
Calling `model.predict` on this state returns:
```
array([[[-0.79352564, -0.22876596,  2.309589  ],
        [-0.10996505,  0.01430818,  0.22286436],
        [-0.17374574,  0.03645202,  0.10073717],
        [-0.19824156,  0.07159233,  0.08594725],
        [-0.12234195,  0.03734204,  0.19439939],
        [-0.21589771,  0.088783  ,  0.08315123],
        [-0.22866695,  0.10703149,  0.07550874],
        [-0.25188142,  0.1436682 ,  0.05827002],
        [-0.25386256,  0.13714936,  0.06612003],
        [-0.26608405,  0.1581351 ,  0.05540368]]], dtype=float32)
```
I'm supposed to get q(s, a1), q(s, a2) and q(s, a3), where a1, a2 and a3 stand for the actions short, flat and long respectively, and then pick the Q value of the action sampled from the experience replay.
Instead, I get a 1x10x3 array.
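The extra axis comes from how a Keras `Dense` layer treats a 3-D input: it acts only on the last axis, so each of the 10 time steps gets its own 3 outputs. The NumPy sketch below mimics that with a hypothetical `dense` helper, random weights and a single linear layer (not the trained two-layer model from the question), and shows that flattening the (10, 2) window into 20 features first gives one row of three values per state.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(x, w, b):
    """Hypothetical stand-in for a linear Dense layer: contracts only the last axis."""
    return x @ w + b


# (batch, time steps, features), the same shape as the state in the question
state = rng.normal(size=(1, 10, 2))
b = np.zeros(3)

# Dense applied directly to the 3-D state: the time axis survives.
w_direct = rng.normal(size=(2, 3))
q_per_step = dense(state, w_direct, b)
print(q_per_step.shape)  # (1, 10, 3)

# Flatten each window to 10 * 2 = 20 features first: one value per action.
flat_state = state.reshape(state.shape[0], -1)
w_flat = rng.normal(size=(20, 3))
q_per_state = dense(flat_state, w_flat, b)
print(q_per_state.shape)  # (1, 3)
```

In the question's model, the corresponding change would be adding `keras.layers.Flatten()` right after the `Input` layer, so that `model.predict` returns shape `(batch_size, 3)`. This is a suggested fix, not something tested in the original post; a recurrent or convolutional layer over the time axis would be another way to collapse it.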
My questions are:

1. How am I supposed to get the Q values?
2. Once that is done, I need the target Q, which is found in a similar way: suppose the array above is what I get by passing `next_state` to a target network. I then have to find the max Q. How can I find the max Q in a 1x10x3 array?
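Once the network returns one row of three Q values per state, shape `(batch_size, 3)`, both questions come down to NumPy indexing. The sketch below assumes a replay minibatch stored as arrays named `actions` (integers 0, 1, 2 for short, flat, long), `rewards` and `dones`, plus a discount factor `gamma`; these names and all the numbers are hypothetical placeholders, and the two Q arrays stand in for `model.predict(states)` and `target_model.predict(next_states)`. It selects Q(s, a) for the sampled actions with integer-array indexing and builds the standard DQN target r + gamma * max over a' of Q_target(s', a'), with no bootstrapping past terminal transitions.

```python
import numpy as np

# Hypothetical replay minibatch of 4 transitions (placeholder numbers).
actions = np.array([0, 2, 1, 2])          # 0 = short, 1 = flat, 2 = long
rewards = np.array([0.5, -0.2, 0.0, 1.0])
dones = np.array([0.0, 0.0, 0.0, 1.0])    # 1.0 marks the last step of an episode
gamma = 0.99

# Stand-ins for the online and target network outputs, shape (batch_size, 3).
q_values = np.array([[0.1, 0.4, -0.3],
                     [0.2, 0.0, 0.7],
                     [-0.5, 0.3, 0.1],
                     [0.6, 0.2, 0.9]])
next_q_values = np.array([[0.3, 0.1, 0.2],
                          [0.0, 0.5, 0.4],
                          [0.2, 0.2, 0.8],
                          [1.0, 0.0, 0.0]])

batch_idx = np.arange(len(actions))

# Question 1: Q(s, a) for the action actually taken in each transition.
q_taken = q_values[batch_idx, actions]

# Question 2: max over the three actions of the target network's Q(s', a').
max_next_q = next_q_values.max(axis=1)

# DQN target: the bootstrap term is dropped for terminal transitions.
targets = rewards + gamma * max_next_q * (1.0 - dones)

# Common pattern when training with a plain MSE loss: copy the online
# predictions and overwrite only the taken actions, so the two untaken
# actions contribute zero error (e.g. model.fit(states, train_targets)).
train_targets = q_values.copy()
train_targets[batch_idx, actions] = targets

print(q_taken)
print(targets)
```

With the current `(1, 10, 3)` output, `max(axis=-1)` would instead return one maximum per time step, shape `(1, 10)`, rather than a single max Q for the state, which is why the output shape has to be reduced to `(batch_size, 3)` first.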
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
