Training on a saved model after learning has stagnated?

I am training a single-layer neural network using PyTorch and saving the model whenever the validation loss decreases. Once the network has finished training, I load the saved model and pass my test-set features through it (rather than through the model from the last epoch) to see how well it does. However, more often than not, the validation loss stops decreasing after about 150 epochs, and I'm worried that the network is overfitting the data. Would it be better for me to load the saved model during training if the validation loss has not decreased for some number of iterations (say, 5 epochs), and then continue training from that saved model instead?
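
Roughly, this is the kind of patience-based reload I have in mind (just a sketch, not my actual code; train_one_epoch, run_validation, and best_model_path are placeholder names for the pieces shown in full further down):

patience = 5
epochs_without_improvement = 0
min_validation_loss = float('inf')

for epoch in range(epochs):
    train_one_epoch(model, optimizer)        # placeholder for the training loop below
    validation_loss = run_validation(model)  # placeholder for the validation loop below

    if validation_loss < min_validation_loss:
        min_validation_loss = validation_loss
        epochs_without_improvement = 0
        torch.save({'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, best_model_path)
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            # roll back to the best checkpoint and continue training from there
            checkpoint = torch.load(best_model_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epochs_without_improvement = 0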

Also, are there any recommendations for how to avoid a situation where the validation loss stops decreasing? I've had some models where the validation loss continues to decrease even after 500 epochs and others where it stops decreasing after 100. Here is my code so far:

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

class NeuralNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, nodes):
        super(NeuralNetwork, self).__init__()
        self.linear1 = nn.Linear(input_dim, nodes)
        self.tanh = nn.Tanh()
        self.linear2 = nn.Linear(nodes, output_dim)

    def forward(self, x):
        output = self.linear1(x)
        output = self.tanh(output)
        output = self.linear2(output)
        return output

epochs = 500 # (start small for now)
learning_rate = 0.01
w_decay = 0.1
momentum = 0.9
input_dim = 4
output_dim = 1
nodes = 8
model = NeuralNetwork(input_dim, output_dim, nodes)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=w_decay) 
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

losses = []
val_losses = []
min_validation_loss = np.inf
means = [] # we want to store the mean and standard deviation for the test set later
stdevs = []
torch.save({
    'epoch': 0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'training_loss': 0.0,
    'validation_loss': 0.0,
    'means': [],
    'stdevs': [],
    }, new_model_path)
new_model_saved = True

for epoch in range(epochs):
    curr_loss = 0.0
    validation_loss = 0.0

    if new_model_saved:
        checkpoint = torch.load(new_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        means = checkpoint['means']
        stdevs = checkpoint['stdevs']
        new_model_saved = False

    model.train()
    for i, batch in enumerate(train_dataloader):
        x, y = batch
        x, new_mean, new_std = normalize_data(x, means, stdevs)
        means = new_mean
        stdevs = new_std
        optimizer.zero_grad()
        predicted_outputs = model(x)
        loss = criterion(torch.squeeze(predicted_outputs), y)
        loss.backward()
        optimizer.step()
        curr_loss += loss.item()

    model.eval()
    with torch.no_grad():  # gradients are not needed for validation
        for x_val, y_val in val_dataloader:
            x_val, val_means, val_std = normalize_data(x_val, means, stdevs)
            predicted_y = model(x_val)
            loss = criterion(torch.squeeze(predicted_y), y_val)
            validation_loss += loss.item()

    curr_lr = optimizer.param_groups[0]['lr']
    if epoch % 10 == 0:
        print(f'Epoch {epoch} \t\t Training Loss: {curr_loss/len(train_dataloader)} \t\t Validation Loss: {validation_loss/len(val_dataloader)} \t\t Learning rate: {curr_lr}')
    if min_validation_loss > validation_loss:
        print(f'     For epoch {epoch}, validation loss decreased ({min_validation_loss:.6f}--->{validation_loss:.6f}) \t learning rate: {curr_lr} \t saving the model')
        min_validation_loss = validation_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_loss': curr_loss/len(train_dataloader),
            'validation_loss': validation_loss/len(val_dataloader),
            'means': means,
            'stdevs': stdevs
            }, new_model_path)
        new_model_saved = True

    losses.append(curr_loss/len(train_dataloader))
    val_losses.append(validation_loss/len(val_dataloader))
    scheduler.step(curr_loss/len(train_dataloader))


Solution 1:[1]

The phenomenon where the validation loss increases while the training loss keeps decreasing is called overfitting. Overfitting is a problem when training a model and should be avoided; please read more on this topic. Overfitting may occur after any number of epochs and depends on many variables (learning rate, dataset size, dataset diversity, and more).

As a rule of thumb, test your model at the "pivot point", i.e. exactly where the validation loss begins to increase while the training loss continues to decrease. In other words, my recommendation is to save the model after each epoch in which the validation loss decreases. If the validation loss keeps increasing for some number X of epochs after that point, it probably means you have already reached a "deep" minimum of the loss and it will not be beneficial to keep training (again, there are exceptions, but for this level of discussion this is enough). I encourage you to read and learn more about this subject; it is very interesting and has significant implications.
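
A minimal sketch of that recommendation: save whenever the validation loss improves, stop once it has not improved for some patience number of epochs, then evaluate the test set with the saved checkpoint. It reuses the model, optimizer, criterion and dataloaders from the question; best_model_path and the patience value are assumptions, and the normalization step is omitted for brevity:

best_val_loss = float('inf')
patience = 10                 # assumed: epochs to wait without improvement
epochs_since_best = 0

for epoch in range(epochs):
    model.train()
    for x, y in train_dataloader:
        optimizer.zero_grad()
        loss = criterion(torch.squeeze(model(x)), y)
        loss.backward()
        optimizer.step()

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x_val, y_val in val_dataloader:
            val_loss += criterion(torch.squeeze(model(x_val)), y_val).item()
    val_loss /= len(val_dataloader)

    if val_loss < best_val_loss:           # new "pivot point" candidate: save it
        best_val_loss = val_loss
        epochs_since_best = 0
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, best_model_path)
    else:
        epochs_since_best += 1
        if epochs_since_best >= patience:  # validation loss has stopped improving
            break

# test-set evaluation should use the best checkpoint, not the last-epoch weights
checkpoint = torch.load(best_model_path)
model.load_state_dict(checkpoint['model_state_dict'])

This keeps the weights from just before the validation loss starts to climb, which is exactly the model you want to pass the test set through.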

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1: Tomer Geva