How to use yield in PyTorch __getitem__() with dataloader?

My intention going forward is to load data into GPU memory only when needed, during training, validation and testing. Since I have a limited amount of VRAM, I decided to structure my dataset in such a way that:

  • it contains only the list of paths to the images I want to work with
  • actual image data is retrieved only during a __getitem__ call

Currently my images are rather small (around 200x200 px) and every image is split into patches (always 36x36 px), each patch undergoing some sort of transformation. In the future, however, I will switch to larger images (think 1000x1000 px or even more), in which case the number of patches will increase.

yield appears to be a good fit here, since it provides a generator that loads data only as required. However, I am new to both yield and PyTorch, so I am struggling to combine the two.
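For context, any function whose body contains yield returns a generator object when called; none of its code runs until the generator is consumed:

def f():
    yield 42

gen = f()        # does NOT execute the body; returns a generator object
print(gen)       # <generator object f at 0x...>
print(next(gen)) # 42 - the body runs only now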

Here is the dataset:

import cv2
import torch
from torch import from_numpy, stack


class CustomDataset(torch.utils.data.Dataset):
  def __init__(self, images):
    '''
    Dataset which loads an image from file and, upon retrieval via __getitem__,
    yields a pair of stacks of patches
    '''
    self.images = images

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    # Load the image from file
    img = cv2.imread(self.images[idx], cv2.IMREAD_GRAYSCALE)
    # Extract patches from the image (a list of 2D arrays)
    patches = self.extract_patches(img)
    # From the original patches create two separate lists of equal length:
    #  - "patches_x" contains patches that have undergone transformation X
    #  - "patches_y" contains patches that have undergone transformation Y
    # Each transformation uses OpenCV or numpy in general, so a final conversion
    # to Tensor is required using "from_numpy"
    patches_x = [from_numpy(transform_X(patch)) for patch in patches]
    patches_y = [from_numpy(transform_Y(patch)) for patch in patches]

    # Yield the pair (this yield is what makes __getitem__ a generator function)
    yield stack(patches_x), stack(patches_y)

The transformations themselves are not important here (that's why I haven't included them). Every time a new item is retrieved, the image at the respective path is loaded (using OpenCV) and split into patches, and then two lists of patches get treated with different transformations (think resizing, blurring etc.). Finally the two lists are converted into stacked tensors and yielded.
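For completeness, extract_patches could look roughly like the sketch below (as a method it would take self as its first argument); my real implementation may pad or overlap, but the idea is to tile the image into non-overlapping 36x36 patches:

def extract_patches(img, patch_size=36):
    # Tile a grayscale image (2D numpy array) into non-overlapping
    # patch_size x patch_size patches; border strips that don't fit are dropped
    h, w = img.shape
    return [img[y:y + patch_size, x:x + patch_size]
            for y in range(0, h - patch_size + 1, patch_size)
            for x in range(0, w - patch_size + 1, patch_size)]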

My question is how to handle loading such a dataset using a DataLoader. I obviously need a custom collate_fn:

def custom_collate(batch):
  # TODO

from torch.utils.data import DataLoader

# "train_image_paths" stands in for a list of image file paths, e.g. collected from ./data/train/
dataset_train = CustomDataset(train_image_paths)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=64, collate_fn=custom_collate)
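One idea I have tried for the collate_fn is to consume each generator and concatenate the per-image patch stacks; whether that is the right approach is part of my question, but for reference it would look roughly like:

from torch import cat

def custom_collate(batch):
    # batch is a list of generator objects, one per image in the batch;
    # next(gen) consumes each generator, producing its (stack_x, stack_y) pair
    pairs = [next(gen) for gen in batch]
    xs, ys = zip(*pairs)
    # Concatenate along dim 0 so all patches of the batch form a single stack
    return cat(xs), cat(ys)

Note that this still loads all patches of batch_size images at once, so it gives no control over the number of patches per batch (see the update below).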

If I just use a DataLoader instance with the default collate function (default_collate)

for sample_patches in iter(dataloader_train):
    print(sample_patches)

I get an error

Traceback (most recent call last):
  File "D:\Projects\networks\net\net.py", line 255, in <module>
    for sample_patches in iter(dataloader_train):
  File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 530, in __next__
    data = self._next_data()
  File "D:\env\ml\lib\site-packages\torch\utils\data\dataloader.py", line 570, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "D:\env\ml\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "D:\env\satellite\lib\site-packages\torch\utils\data\_utils\collate.py", line 180, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'generator'>

which is understandable, since __getitem__() returns a generator (due to the yield statement at the end) rather than tensors. Then again, my use of yield may be incorrect altogether...
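Consuming the generator manually confirms this; only then is the image actually loaded and processed:

gen = dataset_train[0]            # a generator object, not tensors
patches_x, patches_y = next(gen)  # the body of __getitem__ runs only now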

The network is defined as follows:

from torch import nn
from collections import OrderedDict

class Network(nn.Module):
    def __init__(self) -> None:
        super(Network, self).__init__()
        self.model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 64, 7, stride=1, padding=3)),
            ('relu1', nn.ReLU(True)),
            ('conv2', nn.Conv2d(64, 32, 5, stride=1, padding=2)),
            ('relu2', nn.ReLU(True)),
            ('conv3', nn.Conv2d(32, 32, 3, stride=1, padding=1)),
            ('relu3', nn.ReLU(True)),
            ('conv4', nn.Conv2d(32, 1, 3, stride=1, padding=1))
        ]))

    def forward(self, x):
        x = self.model(x)
        return x
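For reference, every convolution uses stride 1 with matching padding, so the spatial dimensions are preserved end-to-end; the stacked grayscale patches only need an explicit channel dimension:

import torch

net = Network()
dummy = torch.randn(64, 1, 36, 36)  # 64 grayscale 36x36 patches in (N, C, H, W)
out = net(dummy)
print(out.shape)                    # torch.Size([64, 1, 36, 36])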

UPDATE:

The problem with my current code is that batch_size only affects the number of images (i.e. the number of image paths) being loaded, since those are directly tied to __len__() and __getitem__(). The patches are a by-product of __getitem__(), so the DataLoader doesn't really know about them.

Example:

Let's have the following collate function

def custom_collate(batch):
    return batch

which doesn't really do anything except return the batch unchanged, i.e. the list of generator objects produced by calling __getitem__().

Configuring the DataLoader

dataloader_train = DataLoader(dataset=dataset_train, batch_size=1, collate_fn=custom_collate)

and iterating over it

for sample in iter(dataloader_train):
    print(sample)

gives

[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154D60>]
[<generator object TEN_DataTrain.__getitem__ at 0x000002095B154DD0>]
...

where the number of [<generator object>] instances is equal to the number of image paths.

If I increase the batch_size to 2, I get

[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684E40>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684EB0>]
[<generator object TEN_DataTrain.__getitem__ at 0x0000021E92684D60>, <generator object TEN_DataTrain.__getitem__ at 0x0000021E92684DD0>]
...

I guess this is how batching is supposed to work, but it does not suit me, since I might have a few large images that yield many, many patches.
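One pattern that might address this (an assumption on my part, I have not verified it) is torch.utils.data.IterableDataset, where the dataset itself yields individual patch pairs so that batch_size counts patches instead of images. A single-worker sketch, reusing extract_patches from above (the class name PatchDataset and the paths variable are placeholders):

import cv2
import torch
from torch import from_numpy
from torch.utils.data import DataLoader

class PatchDataset(torch.utils.data.IterableDataset):
    def __init__(self, images):
        super().__init__()
        self.images = images  # list of image file paths

    def __iter__(self):
        # Yield one (patch_x, patch_y) pair at a time; the DataLoader then
        # batches patches, not images. With num_workers > 0 the image list
        # would additionally have to be sharded across workers.
        for path in self.images:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            for patch in extract_patches(img):
                yield from_numpy(transform_X(patch)), from_numpy(transform_Y(patch))

dataloader_train = DataLoader(dataset=PatchDataset(paths), batch_size=64)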


