How to reset the state of an LSTM RNN after each epoch within Keras?

I have defined a stateful LSTM RNN, and I want to reset the state of the RNN after each epoch. I have found that one way to do this would be:

n_epochs = 50
for i in range(n_epochs):
    lstm.fit(X, y, epochs=1, batch_size=64, shuffle=False)
    lstm.reset_states()
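A self-contained version of that loop might look like the sketch below. The model, the data shapes and the sample count are hypothetical stand-ins for the question's `lstm`, `X` and `y`. Two details are my own assumptions rather than something the question requires: `shuffle=False`, so the batch order (and therefore which sequence the carried-over state belongs to) stays fixed, and `initial_epoch`/`epochs`, so Keras numbers the epochs 0, 1, 2, ... across the separate `fit()` calls instead of restarting at 0 each time. The reset is done on every layer with `stateful=True`, which should also work if a model-level `reset_states()` is not available in your Keras version.

```python
import numpy as np
import tensorflow as tf

batch_size, timesteps, features = 64, 5, 2
n_epochs = 3

# Hypothetical stand-ins for the question's `lstm`, `X` and `y`
inputs = tf.keras.Input(batch_shape=(batch_size, timesteps, features))
hidden = tf.keras.layers.LSTM(8, stateful=True)(inputs)
lstm = tf.keras.Model(inputs, tf.keras.layers.Dense(1)(hidden))
lstm.compile(optimizer="adam", loss="mse")

# 640 samples = 10 full batches of 64 (a stateful model needs full batches)
X = np.random.normal(size=(640, timesteps, features)).astype("float32")
y = np.random.normal(size=(640, 1)).astype("float32")

epochs_seen = []
for i in range(n_epochs):
    # initial_epoch/epochs keep the epoch counter running across fit() calls;
    # shuffle=False keeps the batch order stable for the carried-over state
    history = lstm.fit(X, y, initial_epoch=i, epochs=i + 1,
                       batch_size=batch_size, shuffle=False, verbose=0)
    epochs_seen.extend(history.epoch)
    # Reset the hidden and cell state of every stateful layer before the next epoch
    for layer in lstm.layers:
        if getattr(layer, "stateful", False):
            layer.reset_states()
```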

Is there a more elegant way to do this, either in the model specification or during training, that Keras supports?



Solution 1:[1]

You should be able to solve this with a Keras callback, which is probably a bit more elegant:

import tensorflow as tf

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Clear the stateful LSTM's hidden and cell state at the end of every epoch
        lstm_layer.reset_states()

inputs = tf.keras.layers.Input(batch_shape = (10, 5, 2))
x = tf.keras.layers.LSTM(10, stateful=True)(inputs)
outputs = tf.keras.layers.Dense(1, activation='linear')(x)
model = tf.keras.Model(inputs, outputs)

lstm_layer = model.layers[1]

model.compile(optimizer='adam', loss='mse')
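x_train = tf.random.normal((200, 5, 2))
y_train = tf.random.normal((200, 1))

# shuffle=False keeps the batch order fixed, so sample i of each batch continues
# sample i of the previous batch, which is what a stateful LSTM assumes
model.fit(x_train, y_train, epochs=5, callbacks=[CustomCallback()],
          batch_size=10, shuffle=False)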
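A variant that avoids the global `lstm_layer` could read the model from the callback itself (Keras sets `self.model` when the callback is passed to `fit()`) and reset every layer that has `stateful=True`. This is a sketch under my own assumptions: the class name `ResetStatefulLayers` and its `num_resets` counter are hypothetical, and the shapes simply mirror the answer's example.

```python
import numpy as np
import tensorflow as tf


class ResetStatefulLayers(tf.keras.callbacks.Callback):
    """Hypothetical helper: resets every stateful layer after each epoch."""

    def __init__(self):
        super().__init__()
        self.num_resets = 0  # how many times on_epoch_end has reset the states

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras when the callback is passed to fit()
        for layer in self.model.layers:
            if getattr(layer, "stateful", False):
                layer.reset_states()
        self.num_resets += 1


batch_size, timesteps, features = 10, 5, 2

inputs = tf.keras.Input(batch_shape=(batch_size, timesteps, features))
hidden = tf.keras.layers.LSTM(10, stateful=True)(inputs)
outputs = tf.keras.layers.Dense(1)(hidden)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# 200 samples: a multiple of batch_size, as a fixed-batch stateful model needs
x_train = np.random.normal(size=(200, timesteps, features)).astype("float32")
y_train = np.random.normal(size=(200, 1)).astype("float32")

reset_callback = ResetStatefulLayers()
model.fit(x_train, y_train, epochs=3, batch_size=batch_size,
          shuffle=False, callbacks=[reset_callback], verbose=0)
```

If you would rather not write a class, `tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: lstm_layer.reset_states())` does the same job for the single `lstm_layer` from the answer above.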

Solution 2:[2]

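For experiments only: if you feed all-zero inputs to every sample in the batch for enough timesteps (for example, as many timesteps as your input sequences have), the LSTM effectively forgets what it saw before. This is a consequence of how an LSTM works: its gates respond to the inputs through summation and comparison units, so a long run of zeros typically drives the hidden and cell state toward values set only by the weights and biases. Those values are not necessarily exactly zero, so this is an approximation of calling `reset_states()`, not an exact equivalent.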
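To illustrate the zero-input idea, here is a small NumPy sketch of a standard LSTM cell. It is written by hand, not taken from Keras, and the weights, the constant bias of 0.5 and the number of zero steps are all hypothetical. Two cells start from very different states and then receive the same run of all-zero inputs. Under these assumptions they end up in practically the same state, so the old memory is gone, but the cell state they settle on is clearly non-zero.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x, h, c, W, U, b):
    """One step of a standard LSTM cell; the gate order (i, f, g, o) is arbitrary here."""
    z = W @ x + U @ h + b
    i, f, g, o = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_new = f * c + i * g        # forget part of the old cell state, add new content
    h_new = o * np.tanh(c_new)   # output gate controls what is exposed as h
    return h_new, c_new


units, features, zero_steps = 4, 2, 200
rng = np.random.default_rng(0)

# Hypothetical small random weights and a constant bias (illustration only)
W = rng.normal(scale=0.1, size=(4 * units, features))
U = rng.normal(scale=0.1, size=(4 * units, units))
b = np.full(4 * units, 0.5)

# Two very different "memories" left over from earlier sequences
c_a = rng.normal(size=units)
h_a = np.tanh(c_a)
h_b, c_b = -h_a, -c_a

x_zero = np.zeros(features)
for _ in range(zero_steps):
    h_a, c_a = lstm_step(x_zero, h_a, c_a, W, U, b)
    h_b, c_b = lstm_step(x_zero, h_b, c_b, W, U, b)

print("difference between the two cell states:", np.abs(c_a - c_b).max())
print("largest cell-state value:", np.abs(c_a).max())
```

With all biases set to zero, the all-zero state is a fixed point of this cell, so how close the zero-input trick gets to a true reset depends on the weights and biases the network has learned.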

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 Martijn Pieters