How to select specific labels in the PyTorch MNIST dataset

I'm trying to create dataloaders using only a specific digit from the PyTorch MNIST dataset.

I already tried to create my own Sampler, but it doesn't work and I'm not sure I'm using the mask correctly.

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler=sampler, shuffle=False, num_workers=2)

So far I have had many different types of errors. For this implementation it's "StopIteration". I feel like this should be easy, but I can't find a simple way to do it. Thank you for your help!



Solution 1:[1]

The easiest option I can think of is to reduce the data set in-place:

indices = dataset.targets == 5 # if you want to keep images with the label 5
dataset.data, dataset.targets = dataset.data[indices], dataset.targets[indices]
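A minimal sketch of this in-place filtering technique, using synthetic tensors in place of the real MNIST attributes so it runs without a download (the `data` and `targets` tensors here are hypothetical stand-ins for `dataset.data` and `dataset.targets`):

```python
import torch

# Synthetic stand-in for MNIST: 10 fake 28x28 "images" with labels.
data = torch.randn(10, 28, 28)
targets = torch.tensor([5, 0, 5, 1, 5, 2, 5, 3, 5, 4])

# Boolean mask selecting only the samples labelled 5
indices = targets == 5
data, targets = data[indices], targets[indices]

print(len(targets))  # → 5, and every remaining label is 5
```

Note that this mutates the dataset, so if you need other digits later you have to reload it.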

Solution 2:[2]

StopIteration is raised when your iterator is exhausted. Are you sure your mask is working correctly? It seems like you are passing a list of boolean values, yet torch.nonzero expects floats or ints.

You should write:

mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]

You also need to pass the dataset to your sampler, for example:

sampler = YourSampler(dataset, mask=mask)

with this class definition

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset, mask):
        self.mask = mask
        self.dataset = dataset
    ...

For more details, you can refer to the PyTorch documentation (which shows the source code) to see how more advanced samplers are implemented: https://pytorch.org/docs/stable/_modules/torch/utils/data/sampler.html#SequentialSampler
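Putting those pieces together, here is one hedged sketch of a working mask-based sampler. It is not the answerer's exact code: it replaces the missing `self.indices` attribute by precomputing the selected indices from the mask, and its `__len__` reports the number of selected samples rather than the mask length:

```python
import torch
from torch.utils.data import Sampler

class MaskedSampler(Sampler):
    """Yields only the dataset indices where the mask is nonzero."""

    def __init__(self, mask):
        # torch.nonzero returns an (N, 1) tensor of indices;
        # flatten it into a plain 1-D list of ints.
        self.indices = torch.nonzero(torch.as_tensor(mask)).flatten().tolist()

    def __iter__(self):
        return iter(self.indices)

    def __len__(self):
        # The number of samples the sampler will yield,
        # not the length of the mask.
        return len(self.indices)

mask = [0, 1, 0, 1, 1]
sampler = MaskedSampler(mask)
print(list(sampler))  # → [1, 3, 4]
```

You could then pass `sampler=MaskedSampler(mask)` to the DataLoader as in the question.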

Solution 3:[3]

You could also use torch.utils.data.Subset as follows:

# For labels 5, 6 and 7
indices = [idx for idx, target in enumerate(dataset.targets) if target in [5, 6, 7]]
dataloader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, indices),
                                         batch_size=BATCH_SIZE,
                                         drop_last=True)
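A self-contained sketch of the Subset approach, using a small TensorDataset as a hypothetical stand-in for MNIST so it runs without a download:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Synthetic stand-in for MNIST: 8 fake "images" with labels 0..7.
data = torch.randn(8, 28, 28)
targets = torch.arange(8)
dataset = TensorDataset(data, targets)

# Keep only the samples whose label is 5, 6 or 7
indices = [idx for idx, target in enumerate(targets) if target in [5, 6, 7]]
subset = Subset(dataset, indices)

loader = DataLoader(subset, batch_size=2, drop_last=True)
for images, labels in loader:
    print(labels)  # only labels from {5, 6, 7} appear
```

Unlike Solution 1, this leaves the underlying dataset untouched, so several Subsets with different label selections can share it.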

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Gepeto97
Solution 2: QuantumLicht
Solution 3: hlzl