A strange problem occurs when I train my model using PyTorch

I was training my facial expression recognition model, but my test-set accuracy was always low, so I tried to find the problem.

I replaced my testloader with trainloader, expecting to get the same accuracy. But the two results are different. Why? Am I not using the same dataset?

The results from one epoch are:

Loss in trainset: 1.5687

Accuracy in trainset: 18%

Accuracy of     anger :  5%
Accuracy of   disgust : 11%
Accuracy of      fear : 20%
Accuracy of     happy : 34%
Accuracy of   neutral : 25%
Accuracy of       sad : 12%
Accuracy of  surprise : 18%

-------------------------calculating accuracy-------------------------

Loss in testset: 1.6002 

Accuracy in testset: 13%

Accuracy of     anger :  0%
Accuracy of   disgust :  0%
Accuracy of      fear :  0%
Accuracy of     happy :  0%
Accuracy of   neutral : 100%
Accuracy of       sad :  0%
Accuracy of  surprise :  0%

My training and test code is below. Could someone help me? Thank you very much.

# Class names; index i is the name reported for label value i.
classes = ('anger', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise')

# Shuffle the training data every epoch. With shuffle=False, if `trainset` is
# stored grouped by class (NOTE(review): confirm the dataset's ordering), each
# mini-batch is dominated by a single class. SGD then drifts toward whichever
# class it saw most recently. Any BatchNorm layers in the model would also
# accumulate running statistics (the ones model.eval() uses) from
# unrepresentative batches. Either effect can make train-mode and eval-mode
# results on the same data disagree sharply.
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=args.bn)

# ---------------------------------------------------------------------------
# Training pass: one epoch over `trainloader`. The model is updated after every
# batch while the loss, overall accuracy and per-class accuracy are recorded.
#
# NOTE: these metrics are gathered while the weights are still changing: each
# batch is scored *before* its optimizer step. The model is also in train mode,
# so dropout is active and BatchNorm uses per-batch statistics, if the model has
# such layers. These numbers are therefore not expected to match an eval-mode
# pass over the same data, even when both passes iterate the same loader.
# ---------------------------------------------------------------------------
model.train()

total_loss = 0.0
correct = 0
total = 0
# sum_class[c]: number of samples whose label is c.
# num_class[c]: how many of those samples were predicted correctly.
sum_class = [0. for _ in range(args.classnum)]
num_class = [0. for _ in range(args.classnum)]

for inputs, labels in trainloader:
    inputs, labels = inputs.to(device), labels.to(device)
    total += labels.shape[0]

    # Forward + backward + optimize. Tensors are used directly: wrapping them in
    # torch.autograd.Variable has been a no-op since PyTorch 0.4.
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    # NOTE(review): if `criterion` uses the default reduction='mean',
    # loss.item() is already a per-batch mean. Dividing this running sum by the
    # number of *samples* (below) then under-reports the loss by roughly the
    # batch size; accumulate `loss.item() * labels.shape[0]` instead in that
    # case. Left unchanged because the criterion's reduction is not visible
    # here.
    total_loss += loss.item()
    loss.backward()
    optimizer.step()

    results = torch.argmax(outputs, dim=1)
    hits = labels == results
    correct += hits.sum().item()

    # Per-class bookkeeping over all args.classnum classes, vectorised with
    # bincount rather than comparing Python lists of 0-d tensors one by one.
    # Assumes integer class-index labels, which CrossEntropy-style targets are.
    label_counts = torch.bincount(labels, minlength=args.classnum).tolist()
    hit_counts = torch.bincount(labels[hits], minlength=args.classnum).tolist()
    for c in range(args.classnum):
        sum_class[c] += label_counts[c]
        num_class[c] += hit_counts[c]

total_loss = total_loss / (1 if total == 0 else total)
total_accuracy = correct / (1 if total == 0 else total)
print('Loss in trainset: %.4f\n' % total_loss)
print('Accuracy in trainset: %d%%\n' % (total_accuracy*100))
for i in range(args.classnum):
    # A class with no samples is reported as 0% instead of dividing by zero.
    class_accuracy = 100 * num_class[i] / sum_class[i] if sum_class[i] else 0.0
    print('Accuracy of %9s : %2d%%' % (classes[i], class_accuracy))

print('\n-------------------------calculating accuracy-------------------------\n')

# ---------------------------------------------------------------------------
# Evaluation pass: eval mode, no parameter updates and no gradient tracking.
# ---------------------------------------------------------------------------
model.eval()

correct = 0
total = 0
total_loss = 0.0

# sum_class[c]: number of samples whose label is c.
# num_class[c]: how many of those samples were predicted correctly.
sum_class = [0. for _ in range(args.classnum)]
num_class = [0. for _ in range(args.classnum)]

# For debugging, trainloader is iterated here instead of testloader. The
# numbers are still expected to differ from the training pass. That pass ran in
# train mode and scored each batch before that batch's optimizer step. This
# pass uses eval mode and the weights as they stand after the whole epoch.
# for inputs, labels in testloader:
with torch.no_grad():  # inference only: do not build the autograd graph
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        total += labels.shape[0]

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # NOTE(review): same caveat as the training pass. If the criterion
        # averages per batch, this sum should be weighted by the batch size
        # before dividing by `total`.
        total_loss += loss.item()

        results = torch.argmax(outputs, dim=1)
        hits = labels == results
        correct += hits.sum().item()

        # Per-class bookkeeping over all args.classnum classes, vectorised
        # with bincount. Assumes integer class-index labels.
        label_counts = torch.bincount(labels, minlength=args.classnum).tolist()
        hit_counts = torch.bincount(labels[hits], minlength=args.classnum).tolist()
        for c in range(args.classnum):
            sum_class[c] += label_counts[c]
            num_class[c] += hit_counts[c]

loss = (total_loss / (1 if total == 0 else total))
accuracy = (100 * correct / (1 if total == 0 else total))
print('Loss in testset: %.4f \n' % loss)
print('Accuracy in testset: %d%%\n' % accuracy)

for i in range(args.classnum):
    # A class with no samples is reported as 0% instead of dividing by zero.
    class_accuracy = 100 * num_class[i] / sum_class[i] if sum_class[i] else 0.0
    print('Accuracy of %9s : %2d%%' % (classes[i], class_accuracy))


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source