Is there a PyTorch transform to go from 1-channel data to 3 channels?

I am using the EMNIST dataset via the PyTorch datasets, together with a neural network that expects a 3-channel input.

I would like to use PyTorch transforms to replicate my single greyscale channel into 3 channels, so I can use the same net for 1-channel and 3-channel data.

Which transform can I use? Or how would I extend PyTorch transforms as suggested here: https://stackoverflow.com/a/50530923/18895809



Solution 1:[1]

You can simply use:

import torch

img = torch.zeros((1, 200, 200))            # img shape  = (1, 200, 200)
img2 = torch.cat([img, img, img], dim=0)    # img2 shape = (3, 200, 200)

If you prefer, you can even write your own transform based on the above snippet. You need to create a simple callable class, as described in the tutorial at: https://pytorch.org/tutorials/recipes/recipes/custom_dataset_transforms_loader.html
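A minimal sketch of such a callable transform, built from the snippet above (the class name ToThreeChannels is my own, not from the original answer):

```python
import torch

class ToThreeChannels:
    """Callable transform: repeat a 1-channel tensor along the
    channel dimension to produce a 3-channel tensor."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        # img is expected to have shape (1, H, W)
        return torch.cat([img, img, img], dim=0)

transform = ToThreeChannels()
out = transform(torch.zeros((1, 200, 200)))
print(out.shape)  # torch.Size([3, 200, 200])
```

Because it is a plain callable, it can be dropped into a torchvision transforms.Compose pipeline after ToTensor().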

Solution 2:[2]

For training I recommend using img.expand(...) instead of concatenation with torch.cat, since expand does not allocate new memory: all three channels are views of the same underlying storage. While doing so, keep in mind that a 3-channel image (typically RGB) is structurally quite different from a greyscale one, so you may see some degradation in your results.

import torch
img = torch.zeros((1, 200, 200))
img = img.expand(3, *img.shape[1:])
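To illustrate the memory difference between the two approaches (my own demonstration, not part of the original answer): expand returns a view whose channels alias the original storage, while torch.cat allocates a fresh tensor.

```python
import torch

img = torch.zeros((1, 200, 200))

expanded = img.expand(3, *img.shape[1:])    # view: no new memory allocated
copied = torch.cat([img, img, img], dim=0)  # copy: fresh allocation

# A write through the original tensor shows up in the expanded view...
img[0, 0, 0] = 1.0
print(expanded[2, 0, 0].item())  # 1.0 -- all three channels alias img's data
print(copied[2, 0, 0].item())    # 0.0 -- the concatenated copy is unaffected
```

Note that because the expanded channels share storage, in-place writes through the expanded tensor itself are unsafe; treat it as read-only input to the network.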

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Deusy94
Solution 2: (no attribution given)