Splitting a Tensor channelwise
I am dumping a tensor of size [1,3,224,224] to a file and would like to split it into 3 tensors of size [1,1,224,224], one for each RGB channel, and dump them into 3 separate files. How do I implement this?
Solution 1:[1]
I think the simplest way is with a loop:
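```python
import torch
for c in range(x.shape[1]):
    # c:c+1 keeps a singleton channel dimension, so each slice is [1, 1, 224, 224]
    torch.save(x[:, c:c+1, ...], f'channel{c}.pth')
```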
Note the indexing of the second (channel) dimension: you want the saved tensor to keep a singleton channel dimension. If you were to index it with `x[:, c, ...]` instead, you would get a tensor of shape [1, 224, 224] (no channel dimension at all).
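To check this end to end, here is a minimal self-contained sketch. It assumes random data in place of the real tensor and a temporary directory for the output files. It also shows `torch.split(x, 1, dim=1)` as an alternative way to get the same per-channel slices. The slices are views that share `x`'s storage, and `torch.save` preserves views by writing the whole underlying storage. Because of that, the sketch calls `.clone()` before saving so each file holds only its own channel.

```python
import os
import tempfile

import torch

# Assumption: random data stands in for the [1, 3, 224, 224] tensor from the question.
x = torch.rand(1, 3, 224, 224)

# Slicing with c:c+1 keeps the channel dimension; integer indexing drops it.
print(x[:, 0:1, ...].shape)  # torch.Size([1, 1, 224, 224])
print(x[:, 0, ...].shape)    # torch.Size([1, 224, 224])

# torch.split with a split size of 1 along dim=1 yields the same per-channel slices.
channels = torch.split(x, 1, dim=1)

out_dir = tempfile.mkdtemp()  # assumption: a throwaway output directory
for c, ch in enumerate(channels):
    # ch is a view into x's storage; clone() so only this channel's data is written.
    torch.save(ch.clone(), os.path.join(out_dir, f'channel{c}.pth'))

# Reload each file and confirm it matches the corresponding channel slice.
for c in range(x.shape[1]):
    loaded = torch.load(os.path.join(out_dir, f'channel{c}.pth'))
    assert loaded.shape == (1, 1, 224, 224)
    assert torch.equal(loaded, x[:, c:c+1, ...])
```

Without the `.clone()`, each loaded tensor would still have the right shape and values, but every file would contain all three channels' worth of storage.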
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Shai |
