Splitting a Tensor channelwise

I am dumping a tensor of size [1,3,224,224] to a file. I would like to split it into 3 tensors of size [1,1,224,224], one for each RGB channel, and dump them into 3 separate files. How do I implement this?



Solution 1:[1]

I think the simplest way is by a loop:

import torch

# x is the [1, 3, 224, 224] tensor from the question
for c in range(x.shape[1]):
    torch.save(x[:, c:c+1, ...], f'channel{c}.pth')  # c:c+1 keeps the channel dim

Note the indexing of the second (channel) dimension: the slice c:c+1 keeps a singleton channel dimension, so each saved tensor has shape [1, 1, 224, 224]. If you index with x[:, c, ...] instead, you get a tensor of shape [1, 224, 224] (no channel dimension at all).
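As a variation on the loop above, here is a minimal sketch that uses torch.split instead of manual slicing and then reloads the files to check the shapes. The random tensor x, the temporary output directory, and the file names are stand-ins for illustration, not part of the original answer. The clone() call is an optional addition: slices and split results are views of x, and as far as I know torch.save serializes the whole underlying storage of a view, so cloning first should keep each file down to a single channel.

```python
import os
import tempfile

import torch

# Stand-in for the tensor from the question (random data, assumed shape).
x = torch.rand(1, 3, 224, 224)

# Chunk size 1 along dim=1 gives one [1, 1, 224, 224] view per channel.
channels = torch.split(x, 1, dim=1)

out_dir = tempfile.mkdtemp()  # hypothetical output location
paths = []
for c, channel in enumerate(channels):
    path = os.path.join(out_dir, f"channel{c}.pth")
    # clone() copies the data so the file holds only this channel.
    torch.save(channel.clone(), path)
    paths.append(path)

# Reload each file and confirm shape and contents match the original channel.
for c, path in enumerate(paths):
    loaded = torch.load(path)
    assert loaded.shape == (1, 1, 224, 224)
    assert torch.equal(loaded, x[:, c:c+1])
```

Whether to use the loop with slicing or torch.split is mostly a matter of taste; both keep the singleton channel dimension.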

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Shai