LSTM/GRU: setting states to random noise instead of resetting to zero

I am training the following GRU-based model. Note that I pass stateful=True to the GRU constructor.

class LearningToSurpriseModel(tf.keras.Model):
  def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
    self.gru = tf.keras.layers.GRU(rnn_units,
                                   stateful=True,
                                   return_sequences=True,
                                   return_state=True,
                                   reset_after=True  
                                   )
    self.dense = tf.keras.layers.Dense(vocab_size)

  def call(self, inputs, states=None, return_state=False, training=False):
    x = inputs
    x = self.embedding(x, training=training)
    if states is None:
      states = self.gru.get_initial_state(x)
    x, states = self.gru(x, initial_state=states, training=training)
    x = self.dense(x, training=training)

    if return_state:
      return x, states
    else:
      return x

  @tf.function
  def train_step(self, inputs):
    [defining here my training step]

I instantiate my model

model = LearningToSurpriseModel(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units
    )

[compile and do stuff] The custom callback below resets the states manually at the end of each epoch.

gru_layer = model.layers[1]

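class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, gru_layer):
        super().__init__()
        # Keep a reference to the stateful GRU layer whose states get reset.
        self.gru_layer = gru_layer

    def on_epoch_end(self, epoch, logs=None):
        # Resets the stored states to zero.
        self.gru_layer.reset_states()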
        
model.fit(train_dataset, validation_data=validation_dataset,
          epochs=EPOCHS, callbacks=[EarlyS, CustomCallback(gru_layer)], verbose=1)

With this callback, the states are reset to zero. I would like to follow the ideas in https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html and (re)initialize the states to random noise instead. What would be a good way to implement this?

Should I override reset_states() to add a states parameter?
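
One possible approach, given here as a minimal sketch under assumptions rather than as the Keras way of doing it: instead of overriding reset_states(), write zero-mean Gaussian noise directly into the layer's state variables. The helper names noise_like and randomize_states, and the stddev=0.1 default, are hypothetical. The sketch assumes the GRU is already built and stateful, so that gru_layer.states is a list of variables exposing .shape and .assign(...). Depending on the Keras version, reset_states() may already accept a states argument, which is worth checking before overriding it.

```python
import numpy as np


def noise_like(shape, stddev=0.1, rng=None):
    """Return float32 Gaussian noise (mean 0, given stddev) of the given shape."""
    rng = np.random.default_rng() if rng is None else rng
    return rng.normal(loc=0.0, scale=stddev, size=shape).astype(np.float32)


def randomize_states(state_vars, stddev=0.1, rng=None):
    """Overwrite each state variable in place with fresh Gaussian noise.

    Assumption: `state_vars` is an iterable of objects that have a fully
    defined `.shape` and an `.assign(value)` method, e.g. the `states`
    list of a built, stateful GRU layer.
    """
    for var in state_vars:
        shape = tuple(int(d) for d in var.shape)
        var.assign(noise_like(shape, stddev=stddev, rng=rng))
```

In the callback from the question, on_epoch_end could then call randomize_states(self.gru_layer.states) in place of self.gru_layer.reset_states(). The noise scale is a hyperparameter to tune, not a recommended value.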



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
