Why does the CNN predict the same class while the training loss keeps decreasing in a binary classification task?

Network description:
I am using a CNN for a binary classification task. The network structure is described in this article.

I modified the last CONV layer so that its output contains 2 channels instead of only one. (Since the network should yield 2 values in a binary classification problem, I'm not sure whether the article makes a mistake here, but the modified network works fine.)

Here is the modified model:

from torch import nn
import torch.nn.functional as Func


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, ks1, ks2, use_1x1conv=False, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=ks1,
                               padding=(ks1 - 1) // 2,
                               stride=stride)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=ks2,
                               padding=(ks2 - 1) // 2)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels,
                                   out_channels,
                                   kernel_size=1,
                                   stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        Y = Func.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return Func.relu(Y + X)


class PaulNet(nn.Module):
    def __init__(self):
        super(PaulNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=128,
                               kernel_size=5, )
        self.bn1 = nn.BatchNorm2d(128)
        self.res1 = ResBlock(in_channels=128,
                             out_channels=128,
                             ks1=3,
                             ks2=1,
                             use_1x1conv=False)
        self.mp1 = nn.MaxPool2d(kernel_size=2)
        self.res2 = ResBlock(in_channels=128,
                             out_channels=128,
                             ks1=3,
                             ks2=3,
                             use_1x1conv=False)
        self.mp2 = nn.MaxPool2d(kernel_size=2)
        self.res3 = ResBlock(in_channels=128,
                             out_channels=128,
                             ks1=3,
                             ks2=3,
                             use_1x1conv=False)
        self.res4 = ResBlock(in_channels=128,
                             out_channels=128,
                             ks1=3,
                             ks2=3,
                             use_1x1conv=False)
        self.mp3 = nn.MaxPool2d(kernel_size=2)
        self.res5 = ResBlock(in_channels=128,
                             out_channels=256,
                             ks1=1,
                             ks2=1,
                             use_1x1conv=True)
        self.res6 = ResBlock(in_channels=256,
                             out_channels=512,
                             ks1=1,
                             ks2=1,
                             use_1x1conv=True)
        self.res7 = ResBlock(in_channels=512,
                             out_channels=512,
                             ks1=1,
                             ks2=1,
                             use_1x1conv=False)
        self.conv2 = nn.Conv2d(in_channels=512,
                               out_channels=2,
                               kernel_size=1, )
        self.bn2 = nn.BatchNorm2d(2)
        self.gap = nn.AvgPool2d(kernel_size=15)

    def forward(self, X):
        Y = Func.relu(self.bn1(self.conv1(X)))
        Y = self.mp1(self.res1(Y))
        Y = self.mp2(self.res2(Y))
        Y = self.mp3(self.res4(self.res3(Y)))
        Y = self.res7(self.res6(self.res5(Y)))
        Y = self.gap(self.bn2(self.conv2(Y)))
        return Y.view(-1, 2, 1)
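
As a quick sanity check, the model can be run on a dummy batch (the 1 x 124 x 124 input size here is only my assumption; it is the size for which the three 2x2 max pools leave a 15x15 map for the final average pool):

import torch

model = PaulNet()
dummy = torch.randn(4, 1, 124, 124)   # batch of 4 single-channel inputs (assumed size)
out = model(dummy)
print(out.shape)                      # torch.Size([4, 2, 1])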

I trained the model for 400 epochs in the first place, using a balanced data set of 1 target class and 1 noise class with 1800 samples of each class. In every training epoch, 60% of the samples are used for training and 20% for validation; the remaining 20% are held out for testing.
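
For completeness, the 60/20/20 split can be set up roughly like this (a sketch; full_dataset is a placeholder name for the combined 3600-sample set and is not shown here):

import torch
from torch.utils.data import DataLoader, random_split

# full_dataset: placeholder for the balanced 1800 + 1800 sample set
n_total = len(full_dataset)
n_train = int(0.6 * n_total)
n_valid = int(0.2 * n_total)
n_test = n_total - n_train - n_valid
train_set, valid_set, test_set = random_split(
    full_dataset, [n_train, n_valid, n_test],
    generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=1)   # valid_one_epoch below assumes batch_size=1
test_loader = DataLoader(test_set, batch_size=1)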

The training loss, recall rate (true positive rate) and false alarm rate (false positive rate) are recorded for observation.
Other hyperparameters are listed below, followed by a sketch of how they are wired up:

  • batch size: 64
  • learning rate: 0.0001
  • adam_betas = (0.9, 0.99)
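
A sketch of the setup (CrossEntropyLoss and the device handling are my assumptions; the actual loss function isn't shown above):

import torch
from torch import nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PaulNet().to(device)
loss_function = nn.CrossEntropyLoss()   # assumption: standard 2-class cross entropy
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))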

My problem:
The training result is shown here: diagram

The training loss keeps decreasing, but the trend of the performance is weird. Ideally, the recall rate should increase and the fa rate should decrease. In my case, however, the recall rate and the fa rate both stay at 1 almost all the time. They occasionally change dramatically for a single epoch, but return to 1 immediately in the next epoch. This means the model predicts the same class for all the input data.

Although around epoch 240 the model briefly achieves perfect performance, with a 100% recall rate and a 0% fa rate (also verified on the test set, which shows the same good result, so the model isn't over-fitting), it still predicts the same class most of the time. The training loss converges, but the performance is not stable at all.
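
One way to see this directly is to count the predicted classes over the whole validation loader (a sketch reusing model, valid_loader and device from the sketches above):

import torch
from collections import Counter

model.eval()
counts = Counter()
with torch.no_grad():
    for feature, ground_truth in valid_loader:
        prediction = model(feature.to(device))            # shape (N, 2, 1)
        counts.update(prediction.argmax(dim=1).flatten().tolist())
print(counts)   # e.g. Counter({1: 720}) when every sample is predicted as class 1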

What I've Tried:

  1. Observe the sample data during the training and validation processes:
    My training and validation code is shown below:
import time

import torch

# `device` is assumed to be defined elsewhere, e.g. torch.device('cuda')


def train_one_epoch(data_loader, loss_function, optimizer, train_model):
    train_model.train()
    size = len(data_loader.dataset)
    for batch, data in enumerate(data_loader):
        iter_start = time.perf_counter()
        feature, ground_truth = data
        feature, ground_truth = feature.to(device), ground_truth.to(device)
        # forward
        prediction = train_model(feature)
        loss = loss_function(prediction, ground_truth)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # iteration time & training loss
        iter_dur = time.perf_counter() - iter_start
        print(f'{batch}/{size // len(feature)} iters '
              f'in {iter_dur:.2f}s, loss = {loss.item():.6f}')
    return loss.item()


def valid_one_epoch(data_loader, loss_function, valid_model):
    valid_model.eval()
    size = len(data_loader.dataset)
    target, noise = 0, 0
    valid_loss, hit, fa = 0, 0, 0
    with torch.no_grad():
        for batch, data in enumerate(data_loader):
            iter_start = time.perf_counter()
            feature, ground_truth = data
            feature, ground_truth = feature.to(device), ground_truth.to(device)
            # forward
            prediction = valid_model(feature)
            loss = loss_function(prediction, ground_truth)
            # statistics
            valid_loss += loss.item()
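            # assumes the validation DataLoader uses batch_size=1 (otherwise .item() fails)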
            if ground_truth.item() == 1:
                target += 1
                hit += (prediction[0].argmax(dim=0).item() == 1)
            else:
                noise += 1
                fa += (prediction[0].argmax(dim=0).item() == 1)
        # recall & fa
        recall = hit / target
        fa = fa / noise
        val_loss = valid_loss / size
        print(f'target: {target}, hit: {hit}, noise:{noise}, fa: {fa}, valid loss: {val_loss:.6f}')
    return recall, fa, val_loss

I found that in the forward pass during training, the model gives correct predictions for different input samples. This explains why the training loss is steadily decreasing.

However, the forward pass during validation does not give the right predictions. The only differences are that 1) I call model.train() during training and model.eval() during validation, and 2) I wrap the validation loop in torch.no_grad().
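
For reference, a tiny standalone example of what the model.train()/model.eval() switch changes for a BatchNorm layer (torch.no_grad() only disables gradient tracking and does not change the outputs):

import torch
from torch import nn

bn = nn.BatchNorm2d(2)
x = torch.randn(8, 2, 4, 4) * 5 + 3     # a batch whose statistics differ from (0, 1)

bn.train()
y_train = bn(x)    # normalized with this batch's own mean/variance
bn.eval()
y_eval = bn(x)     # normalized with the running estimates accumulated so far

print(torch.allclose(y_train, y_eval))  # False: the two modes generally give different outputs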

I have no idea why this happens, so I assumed it was caused by the difference between the training and validation sets and made another attempt:

  2. Use the same data set for both the training and validation process (see the sketch below):
    This time I used a 100% identical data set for training and validation. I reduced the number of training epochs to 100 and the data set to 900 samples per class to speed up training. I also recorded the validation loss. The training loss and validation loss should be rather close on the same data set during the process.
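
Concretely, something like this (a sketch; small_dataset is a placeholder for the reduced 2 x 900-sample set, and loss_function/optimizer are as in the setup sketch above):

from torch.utils.data import DataLoader

# Both loaders draw from the identical reduced dataset; only the batching
# differs, because valid_one_epoch above expects batch_size=1.
train_loader = DataLoader(small_dataset, batch_size=64, shuffle=True)
valid_loader = DataLoader(small_dataset, batch_size=1)

for epoch in range(100):
    train_loss = train_one_epoch(train_loader, loss_function, optimizer, model)
    recall, fa, val_loss = valid_one_epoch(valid_loader, loss_function, model)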

But they were not. The training loss keeps decreasing, while the validation loss fluctuates greatly with a general trend of slight increase. The result is shown here: diagram

That is my whole question; thanks a lot for reading this far. I hope I've provided all the necessary information.



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
