Fluctuations and overfitting in the first epochs

I am training a CNN on the DVS Gesture dataset using PyTorch. However, training does not progress smoothly: both the training and validation accuracies fluctuate a lot. Both improve over time, but there is a large gap between them (5–6%, up to 10%), as if overfitting sets in around epoch 3 or 4. I have tried L2 regularization as well as dropout with high values; the gap disappears in the first iterations but reappears strongly afterward. I am sure the datasets are properly merged and randomly split. I changed the batch size several times, but it had no impact, and normalization makes things worse.

PS: Could this be underfitting instead? How can I identify underfitting?

Thanks in advance!

Code (using the snnTorch library):

# Hyperparameters shared by the network definition and the training loop.
num_epochs = 200  # number of passes over the training loader
beta = 0.72  # `beta` argument given to every snn.Leaky layer
# Surrogate gradient function given to every snn.Leaky layer as `spike_grad`.
spike_grad = surrogate.fast_sigmoid(slope=5.4)

class Net(nn.Module):
    """Two-stage convolutional spiking network with an 11-neuron spiking readout.

    At every time step the input frame goes through
    conv -> average-pool -> leaky neuron twice; the resulting spikes are
    flattened, projected to 11 values by a linear layer, passed through
    dropout and fed to a final leaky layer.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 2 input channels -> 16 feature maps.
        self.conv1 = nn.Conv2d(2, 16, kernel_size=5, bias=False)
        self.pool1 = nn.AvgPool2d(2)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=2.5)

        # Stage 2: 16 -> 32 feature maps.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, bias=False)
        self.pool2 = nn.AvgPool2d(2)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=2.5)

        # 800 input features = 32 channels * 5 * 5, which the two
        # (5x5 conv, 2x2 pool) stages produce from 32x32 frames.
        self.fc1 = nn.Linear(800, 11)
        # NOTE(review): p=0.93 is applied to the 11 output currents. In train
        # mode nn.Dropout zeroes each element with probability p and rescales
        # the survivors by 1 / (1 - p); in eval mode it is the identity. That
        # alone makes training-time predictions far noisier than validation
        # ones -- consider a much smaller p, or moving dropout before fc1.
        self.drop1 = nn.Dropout(0.93)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad, threshold=2.5)

        self.flatten = nn.Flatten()

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: 5-D tensor whose dimension 1 is iterated as time; each
                ``x[:, step]`` slice is fed to ``conv1`` and must therefore be
                (batch, 2, height, width).

        Returns:
            Tuple ``(spk_rec, mem_rec)``: the last layer's output spikes and
            membrane values, each stacked along a new leading time dimension.
        """
        # Fresh membrane state for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()
        spk_rec = []
        mem_rec = []
        for step in range(x.size(1)):
            # Same slice as x.permute(1, 0, 2, 3, 4)[step], without rebuilding
            # the permuted view on every iteration.
            frame = x[:, step]

            cur1 = self.pool1(self.conv1(frame))
            spk1, mem1 = self.lif1(cur1, mem1)

            # Use the pooling layer that belongs to this stage (the original
            # reused pool1 here and left pool2 unused; both are AvgPool2d(2)).
            cur2 = self.pool2(self.conv2(spk1))
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.drop1(self.fc1(self.flatten(spk2)))
            spk3, mem3 = self.lif3(cur3, mem3)

            spk_rec.append(spk3)
            mem_rec.append(mem3)
        return torch.stack(spk_rec), torch.stack(mem_rec)

# Model, optimizer, LR schedule and loss used by the training loop.
net_9 = Net().to(device)
optimizer = torch.optim.Adam(net_9.parameters(), lr=7.5e-3, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=735, eta_min=0, last_epoch=-1)
loss = SF.mse_count_loss() # spk mse

# Per-batch histories, appended to on every training / validation batch.
train_loss_hist_9 = []
valid_loss_hist_9 = []
train_acc_hist_9 = []
valid_acc_hist_9 = []
# Checkpoint path prefix; the epoch index is appended when saving.
# A forward slash is used because "\n" inside a normal string literal is a
# newline escape, which silently corrupted the original "1-DVS\net_9_" path.
path_9 = "1-DVS/net_9_"

# Main loop: one training pass and one validation pass per epoch, followed by
# a checkpoint and an accuracy report.
for epoch in range(num_epochs):
    # Batch counters handed to print_epoch_accuracy at the end of the epoch.
    batch_train = batch_valid = 0

    # Minibatch training loop
    net_9.train()
    for data_train, targets_train in train_loader:
        data_train = data_train.to(device)
        targets_train = targets_train.to(device)
        # Call the module itself (not .forward) so nn.Module hooks still run.
        spk_train, _ = net_9(data_train)
        loss_train = loss(spk_train, targets_train)
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
        # The LR schedule is advanced once per optimizer step.
        scheduler.step()
        # Predicted class = output neuron with the most spikes over time.
        _, idx = spk_train.sum(dim=0).max(1)
        acc_train = np.mean((targets_train == idx).detach().cpu().numpy())
        train_acc_hist_9.append(acc_train.item())
        train_loss_hist_9.append(loss_train.item())
        batch_train += 1

    # Minibatch validation loop
    net_9.eval()
    with torch.no_grad():
        for data_valid, targets_valid in valid_loader:
            data_valid = data_valid.to(device)
            targets_valid = targets_valid.to(device)
            spk_valid, _ = net_9(data_valid)
            loss_valid = loss(spk_valid, targets_valid)
            _, idx = spk_valid.sum(dim=0).max(1)
            acc_valid = np.mean((targets_valid == idx).detach().cpu().numpy())
            valid_acc_hist_9.append(acc_valid.item())
            valid_loss_hist_9.append(loss_valid.item())
            batch_valid += 1

    # No scheduler.step(metric) here: that call shape belongs to
    # ReduceLROnPlateau. On other schedulers the positional argument is taken
    # as an epoch index, so passing the validation loss would overwrite the
    # schedule position with the loss value at the end of every epoch.
    torch.save({'model_state_dict': net_9.state_dict()}, path_9 + str(epoch))

    print("----------------------------------------------------------------------")
    print_epoch_accuracy(train_acc_hist_9, valid_acc_hist_9, batch_train, batch_valid)
    print("----------------------------------------------------------------------")
    print("\n")


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source