How can I implement early stopping and reduce learning rate on plateau in TensorFlow?

I want to implement two callbacks, EarlyStopping and ReduceLearningRateOnPlateau, for a neural network model built with TensorFlow. (I'm not using Keras.)

The sample code below shows how I implemented early stopping in my script. I don't know whether it is correct.

# A list to record loss on validation set
val_buff = []
# If early_stop == True, then terminate training process
early_stop = False

while icount < maxEpoches:

    '''Shuffle the training set'''
    '''Update the model by using Adam optimizer over the entire training set'''

    # Evaluate loss on validation set
    val_loss = self.sess.run(self.loss, feed_dict = feeddict_val)
    val_buff.append(val_loss)

    if icount % ep == 0:

        diff = np.array([val_buff[ind] - val_buff[ind - 1] for ind in range(1, len(val_buff))])
        bad = len(diff[diff > 0])
        if bad > 0.5 * len(diff):
            early_stop = True

        if early_stop:
            self.saver.save(self.sess, 'model.ckpt')
            raise OverFlow()
        val_buff = []

    icount += 1

When I train the model and keep track of the loss on the validation set, I find that the loss goes up and down, so it is hard to tell when the model starts to overfit.

Since EarlyStopping and ReduceLearningRateOnPlateau are quite similar, how can I modify the code above to implement ReduceLearningRateOnPlateau?



Solution 1:[1]

Oscillating error/loss is pretty common. The main issue with implementing an early stopping or learning rate reduction rule is that the validation loss is computed relatively rarely. To work around this, I'd suggest the following rule: stop training when the best validation loss was recorded more than N epochs ago.

max_stagnation = 5  # number of epochs without improvement to tolerate
best_val_loss, best_val_epoch = None, None
early_stop = False

for epoch in range(max_epochs):
    # train an epoch ...
    val_loss = evaluate()
    # lower loss is better, so record a new best when the loss decreases
    if best_val_loss is None or val_loss < best_val_loss:
        best_val_loss, best_val_epoch = val_loss, epoch
    if best_val_epoch < epoch - max_stagnation:
        # nothing has improved for a while
        early_stop = True
        break
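
The same bookkeeping can be extended to reduce the learning rate on a plateau. The following is only a minimal sketch in plain Python, not a TensorFlow API: the class name PlateauTracker, its parameters (factor, lr_patience, stop_patience, min_lr) and the hard-coded list of validation losses are all hypothetical and chosen for illustration. It assumes you feed the current learning rate into your optimizer yourself on each step, for example through a placeholder in feed_dict.

```python
class PlateauTracker:
    """Hypothetical helper: tracks the best validation loss, lowers the
    learning rate on a plateau, and signals when to stop training."""

    def __init__(self, lr, factor=0.5, lr_patience=3, stop_patience=10, min_lr=1e-6):
        self.lr = lr                        # current learning rate
        self.factor = factor                # multiplier applied on a plateau
        self.lr_patience = lr_patience      # epochs to wait before reducing lr
        self.stop_patience = stop_patience  # epochs to wait before stopping
        self.min_lr = min_lr                # lower bound for the learning rate
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.last_reduction_epoch = -1

    def update(self, epoch, val_loss):
        """Record val_loss for this epoch; return True if training should stop."""
        if val_loss < self.best_loss:
            self.best_loss, self.best_epoch = val_loss, epoch
            return False
        # Count from the best epoch or the last reduction, whichever is later,
        # so the learning rate is not reduced on every stagnant epoch.
        reference = max(self.best_epoch, self.last_reduction_epoch)
        if epoch - reference >= self.lr_patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.last_reduction_epoch = epoch
        return epoch - self.best_epoch >= self.stop_patience


# Illustrative run on made-up validation losses (not real training results).
val_losses = [1.0, 0.8, 0.9, 0.85, 0.9, 0.95, 0.9, 0.9]
tracker = PlateauTracker(lr=0.1, factor=0.5, lr_patience=2, stop_patience=5)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    # In real training: run one epoch with tracker.lr, then evaluate val_loss.
    if tracker.update(epoch, val_loss):
        stopped_at = epoch
        break
```

In a real loop you would read tracker.lr before each training epoch and save a checkpoint whenever a new best loss is recorded, so that stopping late does not cost you the best model.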

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 y.selivonchyk