'Neural Network predicting same label for all points

I'm doing segmentation of a point cloud using PointNet. I have trained a model using the dataset from this link http://buildingparser.stanford.edu/dataset.html, and during the training, everything looks good in terms of loss and accuracy. However, when I try to afterward make predictions using the trained model, all predictions are the same, in particular label 2 which corresponds to 'wall'. The dataset that is used for the training does not have evenly distributed data, but still, it seems weird that it predicts everything as 'wall'.

from torch.utils.data import Dataset, DataLoader
import torch
import os
import h5py
import numpy as np
import sys
from torch import nn
import torch.optim as optim
from matplotlib import pyplot
from torch.utils.data import TensorDataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.tensorboard import SummaryWriter
import torchvision


from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D

#Own classes
from network import PointNetSeg
from loss import PointNetLoss
import ExtractH5Data


class Trainer:
    def __init__(self):

        # TrainingParameters:
        self.epochCount = []
        self.lossCount = []

        self.batch_size = 2
        self.lr = 0.001
        self.n_epochs = 15
        self.model_path = "/Users/Mikke/PycharmProjects/pointnet/model/model.pth"
        self.load_model = True
        self.compute_validation = False

        # Use GPU?
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device("cpu")
        print("Training on Device: ", self.device)
        if self.load_model is not True:
            self.writer = SummaryWriter()
        if  self.compute_validation == True:
            self.writer_val = SummaryWriter('Validation')

        # Get data
        data = ExtractH5Data.DataExtractor().returnData()

        # Get training-data, validation-data and test-data
        train_data, val_data, train_labels, val_labels = train_test_split(np.asarray(data[0]), np.asarray(data[1]),
                                                                          test_size=0.20, random_state=8)
        test_data = np.asarray(data[2])
        test_labels = np.asarray(data[3])

        print('\n\ntrain size: ',train_data.shape, train_labels.shape ,'\nVal size: ', val_data.shape, val_labels.shape,'\ntest size: ' ,test_data.shape, test_labels.shape)

        #set the dataloader
        dataset_train = TensorDataset(torch.tensor(train_data), torch.tensor(train_labels))
        dataset_val = TensorDataset(torch.tensor(val_data), torch.tensor(val_labels))
        dataset_test = TensorDataset(torch.tensor(test_data), torch.tensor(test_labels))

        self.dataloader = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.dataloader_val = DataLoader(dataset_val, batch_size=self.batch_size, shuffle=True, drop_last=True)
        self.dataloader_test = DataLoader(dataset_test, batch_size=self.batch_size, shuffle=False, drop_last=True)

        #  Network:
        self.net = PointNetSeg(self.device).to(self.device)


        # Optimizer:
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)

        # Loss:
        self.loss = PointNetLoss(self.device)

        # Load Model?
        if self.load_model and os.path.isfile(self.model_path):
            self.net.load_state_dict(torch.load(self.model_path))
            self.net.eval()
            print("Loaded Path: ", self.model_path)



    def train(self):
        i = 1
        for epoch in range(self.n_epochs):
            running_loss = 0.0
            running_acc = 0.0
            #Training Loop:
            self.net.train()
            for i, (points, target) in enumerate(self.dataloader, start=1):
                #if torch.cuda.is_available():
                    points = points.to(self.device)
                    target = target.to(self.device, dtype = torch.int64)

                    # Compute Network Output
                    pred, A = self.net(points)

                    # Compute Loss
                    loss = self.loss(target, pred, A)

                    # Optimize:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    running_loss += loss.item()


                    pred_ = torch.argmax(pred, 1)
                    acc = ExtractH5Data.DataExtractor.CalculateACC(self, prediction=pred_, label=target, batchSize=self.batch_size)
                    running_acc += acc.item()


            print("Epoch: %d, Error Loss: %f, acc: %f" % (epoch, running_loss / i, running_acc / i), '%')
            self.writer.add_scalar("Loss/train", running_loss / i, epoch)
            self.writer.add_scalar('Accuracy', running_acc / i, epoch)

            #Save the model:
            torch.save(self.net.state_dict(), self.model_path)

            # Validate:
            self.net.eval()
            val_loss = 0
            val_counter = 0
            if self.compute_validation:
                for i, (points, target) in enumerate(self.dataloader_val):
                    #if torch.cuda.is_available(): # if we don´t do this we will run out of memory
                        points = points.to(self.device)
                        target = target.to(self.device, dtype = torch.int64)

                        pred, A = self.net(points)
                        loss = self.loss(target, pred, A)
                        val_loss += loss
                        if i % 100 == 0:
                            print('In validation: ', points.shape, target.shape, pred.shape)
                            print("Epoch: %d, i: %d, Validation Loss: %f" % (epoch, i, val_loss))
                            self.writer_val.add_scalar("Loss/train", loss, epoch)

                            val_counter += 1
        self.writer.flush()



    def mIoU(self):
        for i, (points, target) in enumerate(self.dataloader_test):
            points = points.to(self.device)
            target = target.to(self.device, dtype = torch.int64)
            pred, _ = self.net(points)
            #Find arg max of prediction:
            #max_ = torch.max(pred, 1)[1]
            max_ = torch.argmax(pred,1)
            print(max_[0:200])
            #print(pred.data.argmax(2)[1].shape)
            input("...")
            #if i == 5:
                #ExtractH5Data.DataExtractor.Visualize_shapeInColors(self, points=points[1,:,:], predictions=max_[1], labels=target[1])
            #    input("...")

if __name__ == "__main__":

    trainer = Trainer()
    #trainer.train()
    trainer.mIoU()

This is the data extractor file:

import numpy as np
import h5py
import os
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
import sys


class DataExtractor:

    def getDataFiles(self, path):
        return [line.rstrip() for line in open(path)]

    def loadDataFile(self, filename, path2):
        return self.load_h5(filename, path2)

    def load_h5(self, h5_filename, path2):
        f = h5py.File(os.path.join(path2, h5_filename), 'r')
        data = f['data'][:]
        label = f['label'][:]
        return (data, label)


    def returnData(self):
        data = self.GetData()
        return data


    def GetData(self):
        PATH = '/Users/Mikke/PycharmProjects/HDF5_data/indoor3d_sem_seg_hdf5_data'
        path2 = '/Users/Mikke/PycharmProjects/HDF5_data'
        ALL_FILES = self.getDataFiles(os.path.join(PATH, 'all_files.txt'))
        room_filelist = [line.rstrip() for line in open(os.path.join(PATH, 'room_filelist.txt'))]
        print(len(ALL_FILES))

        data_batch_list = []
        label_batch_list = []
        counter = 0
        for h5_filename in ALL_FILES:
            if counter < 10:  # If more data is needed up the number!
                print(h5_filename)
                data_batch, label_batch = self.loadDataFile(h5_filename, path2)
                data_batch_list.append(data_batch)
                label_batch_list.append(label_batch)
                counter += 1
        data_batches = np.concatenate(data_batch_list, 0)
        label_batches = np.concatenate(label_batch_list, 0)

        test_area = 'Area_' + str(4)

        train_idxs = []
        test_idxs = []
        for i, room_name in enumerate(room_filelist):
            if i < len(data_batches):
                if test_area in room_name:
                    test_idxs.append(i)
                else:
                    train_idxs.append(i)

        train_data = data_batches[train_idxs, ...]
        train_label = label_batches[train_idxs]
        test_data = data_batches[test_idxs, ...]
        test_label = label_batches[test_idxs]

        # Correct labels into our own desire
        np.set_printoptions(threshold=sys.maxsize)

        #reshape data to work with PoinetNet Architecture
        train_data = train_data[:, :, 0:3]
        train_data = train_data.reshape(-1, 3, 4096)

        test_data = test_data[:, :, 0:3]
        test_data = test_data.reshape(-1, 3, 4096)

        return train_data, train_label, test_data, test_label


    def FindDistributionOfPoints(self, labels):
        Label4 = (labels == 5).sum()
        Label0 = (labels == 6).sum()
        Label1 = (labels == 2).sum()
        Label2 = (labels == 7).sum()
        #totalPoints = len(labels)*len(labels[1])
        totalPoints = len(labels)
        print(totalPoints)
        print('label sizes: \n', 'Window 0: ',(Label4/totalPoints)*100,'% door 1: ', (Label0/totalPoints)*100,'% wall 2: ',
              (Label1/totalPoints)*100,'% table 3: ', (Label2/totalPoints)*100,'% other 4: ', 100 * ((totalPoints - Label4 - Label0 - Label1 - Label2) / totalPoints), "%")




    def Visualize_shapeInColors(self, predictions, points, labels):

        points = points.cpu().numpy()
        predictions = predictions.cpu().numpy()
        labels = labels.cpu().numpy()
        clouds = points.reshape(4096, 3)
        DataExtractor.FindDistributionOfPoints(self, predictions)
        DataExtractor.FindDistributionOfPoints(self, labels)

        # Calculate acc
        acc = DataExtractor.CalculateACC(self, prediction=predictions, label=labels)

        print(f'Accuracy is: {acc}%')
        window = []
        door = []
        wall = []
        table = []
        remaining = []
        predictions = labels
        print(predictions[0:500])
        for i, cloud in enumerate(clouds):
            if predictions[i] == 5:
                window.append(cloud)
            elif predictions[i] == 6:
                door.append(cloud)
            elif predictions[i] == 2:
                wall.append(cloud)
            elif predictions[i] == 7:
                table.append(cloud)
            else:
                remaining.append(cloud)


        print('remaining size: ', np.asarray(remaining).shape, 'ceiling: ', np.asarray(window).shape, 'floor: ',
              np.asarray(door).shape, 'wall: ', np.asarray(wall).shape, 'table: ', np.asarray(table).shape)

        fig = go.Figure()
        if len(remaining) > 0:
            fig.add_trace(go.Scatter3d(x=np.asarray(remaining)[:, 0],
                                       y=np.asarray(remaining)[:, 1],
                                       z=np.asarray(remaining)[:, 2],
                                       name="Remaining", mode='markers', marker=dict(
                                       color='rgba(255, 255, 255, 0.5)', size=2,
                                       ), ))

        if len(window) > 0:
            fig.add_trace(go.Scatter3d(x=np.asarray(window)[:,0], y=np.asarray(window)[:,1], z=np.asarray(window)[:,2],
                                          name="Window", mode='markers', marker=dict(color='rgba(255, 0, 0, 0.5)', size=2,
                                          ), ))
        if len(door) > 0:
            fig.add_trace(go.Scatter3d(x=np.asarray(door)[:,0], y=np.asarray(door)[:,1], z=np.asarray(door)[:,2],
                                     name="Door", mode='markers', marker=dict(color='rgba(0, 255, 0, 0.5)', size=2,
                                                                   ), ))

        if len(wall) > 0:
            fig.add_trace(go.Scatter3d(x=np.asarray(wall)[:,0], y=np.asarray(wall)[:,1], z=np.asarray(wall)[:,2],
                                     name="Wall", mode='markers', marker=dict(color='rgba(0, 0, 255, 0.5)', size=2,
                                                                   ), ))

        if len(table) > 0:
            fig.add_trace(go.Scatter3d(x=np.asarray(table)[:,0], y=np.asarray(table)[:,1], z=np.asarray(table)[:,2],
                                     name="Table", mode='markers', marker=dict(color='rgba(255, 0, 255, 0.5)', size=2,
                                                                   ), ))

        fig.show()



    def CalculateACC(self, prediction, label):
        correctGuess = (prediction == label).sum()



        return 100 * (correctGuess / prediction.size)

This is the Network:

import torch
from torch import nn
import torch.nn.functional as F


class mlp(nn.Module):
    def __init__(self, in_size, out_size, k_size=1, batchnorm=True):
        """ 
        Creates a mlp layer as described in the paper.

        in_size: input size of the mlp
        out_size: output size of the mlp
        relu: apply relu
        batchnorm: apply norm 
        """

        super(mlp, self).__init__()
        self.batchnorm = batchnorm
        self.conv = nn.Conv1d(in_size, out_size, k_size)
        self.bn = nn.BatchNorm1d(out_size)

    def forward(self, x):
        if self.batchnorm:
            return F.relu(self.bn(self.conv(x)))
        else:
            return self.conv(x)


class fc(nn.Module):
    def __init__(self, in_size, out_size, k_size=1, batchnorm=True, dropout=False, dropout_p=0.7):
        """ 
        Creates a fully connected layer as described in the paper.

        in_size: input size of the mlp
        out_size: output size of the mlp
        relu: apply relu
        batchnorm: apply norm 
        """
        super(fc, self).__init__()
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.fc = nn.Linear(in_size, out_size)
        self.bn = nn.BatchNorm1d(out_size)
        self.dp = nn.Dropout(p=dropout_p)

    def forward(self, x):
        if self.batchnorm and not self.dropout:
            return F.relu(self.bn(self.fc(x)))
        elif self.batchnorm and self.dropout:
            return F.relu(self.bn(self.dp(self.fc(x))))
        elif not self.batchnorm:
            return self.fc(x)


class TNet3(nn.Module):
    def __init__(self, device):
        super(TNet3, self).__init__()
        self.device = device

        self.mlp1 = mlp(3, 64)
        self.mlp2 = mlp(64, 128)
        self.mlp3 = mlp(128, 1024)

        self.fc1 = fc(1024, 512)
        self.fc2 = fc(512, 256, dropout=True)
        self.fc3 = fc(256, 9)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.mlp1(x)
        x = self.mlp2(x)
        x = self.mlp3(x)
        x = torch.max(x, 2)[0]

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        # For stability
        iden = torch.eye(3, 3).repeat(batch_size, 1, 1)
        if x.is_cuda:
            iden = iden.to(self.device)

        x = x.view(-1, 3, 3)
        x = x + iden

        return x


class TNet64(nn.Module):
    def __init__(self, device):
        super(TNet64, self).__init__()

        self.device = device

        self.mlp1 = mlp(64, 64)
        self.mlp2 = mlp(64, 128)
        self.mlp3 = mlp(128, 1024)

        self.fc1 = fc(1024, 512)
        self.fc2 = fc(512, 256, dropout=True)
        self.fc3 = fc(256, 64 * 64)

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.mlp1(x)
        x = self.mlp2(x)
        x = self.mlp3(x)
        x = torch.max(x, 2)[0]

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        # For stability
        iden = torch.eye(64, 64).repeat(batch_size, 1, 1)
        if x.is_cuda:
            iden = iden.to(self.device)

        x = x.view(-1, 64, 64)
        x = x + iden

        return x


class PointNetSeg(nn.Module):
    def __init__(self, device, m=13):
        """
        m: number of classes which a single point can be classified into
        """
        super(PointNetSeg, self).__init__()

        self.device = device
        self.m = m

        self.TNet3 = TNet3(self.device)
        self.TNet64 = TNet64(self.device)

        self.mlp1 = mlp(3, 64)
        self.mlp2 = mlp(64, 64)
        self.mlp3 = mlp(64, 64)
        self.mlp4 = mlp(64, 128)
        self.mlp5 = mlp(128, 1024)

        self.mlp6 = mlp(1088, 512)
        self.mlp7 = mlp(512, 256)
        self.mlp8 = mlp(256, 128)
        self.mlp9 = mlp(128, self.m, batchnorm=False)


    def forward(self, x):
        #  input transform:
        x_ = x.clone()
        T3 = self.TNet3(x_)
        x = torch.matmul(T3, x)
        #  mlp (64,64):
        x = self.mlp1(x)
        x = self.mlp2(x)

        # feature transform:
        x_ = x.clone()
        T64 = self.TNet64(x_)

        x = torch.matmul(T64, x)

        x_feature = x.clone()

        #  mlp (64,128,1024):
        x = self.mlp3(x)
        x = self.mlp4(x)
        x = self.mlp5(x)

        x_globfeat = torch.max(x, 2, keepdim=True)[0]

        #  Concatenate global and local features
        x_globfeat = x_globfeat.expand(-1, -1, x_feature.shape[2])
        x = torch.cat((x_feature, x_globfeat), dim=1)  # here we concatenate the global features with the 64 features

        x = self.mlp6(x)
        x = self.mlp7(x)
        x = self.mlp8(x)
        x = self.mlp9(x)

        return x, T64


And lastly the loss function file:

from torch import nn

class PointNetLoss(nn.Module):
    def __init__(self, device, w=0.0001):
        super(PointNetLoss, self).__init__()
        self.w = w
        self.nll_loss = nn.CrossEntropyLoss()
        self.device = device

    def forward(self, gt, pr, A_):
        A = A_.clone()
        # Orthogonality constraint
        orth = torch.norm(torch.eye(A.shape[1]).to(self.device) - torch.matmul(A, A.transpose(1, 2)))
        loss = self.nll_loss(pr, gt) + self.w * orth
        return loss



I hope that someone can help me :)



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source