Best (fastest) way to modify data in a PyTorch loss function?

I want to experiment with creating a modified loss function for 4-channel image data.

What is the best way to split a tensor of torch.Size([64, 4, 128, 128]) into two tensors of torch.Size([64, 3, 128, 128]) and torch.Size([64, 1, 128, 128])?



Solution 1:[1]

I was able to resolve this myself by using the torch.split function.

Given an image tensor of shape torch.Size([64, 4, 128, 128]), you can split along dim 1 (the channel dimension) with a fixed chunk size of 3. This yields one 3-channel chunk and a 1-channel remainder.
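A minimal sketch of the split semantics, assuming a random stand-in tensor `E` in place of real image data. An integer split size produces chunks of that size with a smaller final chunk, while a list of sizes states the [3, 1] split explicitly; plain slicing gives the same result. Per the PyTorch docs, each chunk returned by torch.split is a view of the original tensor, so no data is copied.

```python
import torch

E = torch.randn(64, 4, 128, 128)  # random stand-in for a batch of 4-channel images

# Integer split size: chunks of 3 along dim 1; the last chunk holds the remainder (1).
a, b = torch.split(E, 3, dim=1)

# List of sizes: states the intended [3, 1] split explicitly.
c, d = torch.split(E, [3, 1], dim=1)

# Plain slicing gives the same two tensors.
e, f = E[:, :3], E[:, 3:]

print(a.shape, b.shape)
print(torch.equal(a, c), torch.equal(b, d), torch.equal(a, e), torch.equal(b, f))

# The chunks are views, so the first one starts at the same memory address as E.
print(a.data_ptr() == E.data_ptr())
```

Passing the list [3, 1] is arguably clearer when the channel layout is fixed, since it raises an error if the channel count is not 4 instead of silently producing a different chunking.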

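import torch
E = torch.randn(64, 4, 128, 128)  # stand-in for self.E in the original snippet
E1 = torch.split(E, 3, dim=1)     # chunk size 3 along the channel dim
print(E.shape)
print(E1[0].shape)
print(E1[1].shape)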

Gives:

torch.Size([64, 4, 128, 128])
torch.Size([64, 3, 128, 128])
torch.Size([64, 1, 128, 128])
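To connect the split back to the original goal of a modified loss, here is a hedged sketch of one possible loss. The function name `split_channel_loss`, the choice of mean squared error for each part, and the `extra_weight` parameter are all illustrative assumptions, not from the original post; the post also does not say what the fourth channel holds.

```python
import torch
import torch.nn.functional as F


def split_channel_loss(pred: torch.Tensor, target: torch.Tensor,
                       extra_weight: float = 0.5) -> torch.Tensor:
    """Hypothetical loss: MSE on channels 0-2 plus a weighted MSE on channel 3.

    extra_weight is an illustrative hyperparameter, not part of the original answer.
    """
    # Split [N, 4, H, W] into [N, 3, H, W] and [N, 1, H, W]; both are views, no copy.
    pred_main, pred_extra = torch.split(pred, [3, 1], dim=1)
    target_main, target_extra = torch.split(target, [3, 1], dim=1)

    main_loss = F.mse_loss(pred_main, target_main)
    extra_loss = F.mse_loss(pred_extra, target_extra)
    return main_loss + extra_weight * extra_loss


pred = torch.randn(64, 4, 128, 128, requires_grad=True)
target = torch.randn(64, 4, 128, 128)

loss = split_channel_loss(pred, target)
loss.backward()  # gradients flow back through the split views into pred
print(loss.item(), pred.grad.shape)
```

Because torch.split returns views rather than copies, splitting inside the loss adds little overhead, and autograd propagates gradients to the original 4-channel tensor.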

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Rick De