Fastest way to modify data in a PyTorch loss function?
I want to experiment with creating a modified loss function for 4-channel image data.
What is the best way to split a tensor of shape torch.Size([64, 4, 128, 128])
into
torch.Size([64, 3, 128, 128]) and torch.Size([64, 1, 128, 128])?
Solution 1:[1]
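I was able to resolve this myself by using torch.split.
Given an image tensor of shape torch.Size([64, 4, 128, 128]),
you can split along dim 1 (the channel dimension) with a fixed chunk size of 3. Because 4 is not divisible by 3, the last chunk holds the remaining 1 channel.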
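# torch.split(tensor, split_size, dim): chunks of 3 along dim 1; the last chunk gets the remaining 1 channel
self.E1 = torch.split(self.E, 3, 1)
print(self.E.shape)      # original 4-channel tensor
print(self.E1[0].shape)  # first 3 channels
print(self.E1[1].shape)  # remaining 1 channel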
Gives:
torch.Size([64, 4, 128, 128])
torch.Size([64, 3, 128, 128])
torch.Size([64, 1, 128, 128])
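For anyone wanting to use this inside an actual loss function, here is a minimal sketch. It is not from the original answer: the function name four_channel_loss, the extra_weight parameter, and the choice of MSE for both parts are all assumptions made for illustration. It passes explicit section sizes [3, 1] to torch.split instead of a single chunk size, which makes the intended layout obvious.

```python
import torch
import torch.nn.functional as F


def four_channel_loss(pred, target, extra_weight=0.5):
    """Hypothetical loss: MSE on the first 3 channels plus weighted MSE on the 4th."""
    # Explicit section sizes [3, 1] along dim=1 (the channel dimension).
    pred_rgb, pred_extra = torch.split(pred, [3, 1], dim=1)
    target_rgb, target_extra = torch.split(target, [3, 1], dim=1)
    rgb_term = F.mse_loss(pred_rgb, target_rgb)
    extra_term = F.mse_loss(pred_extra, target_extra)
    return rgb_term + extra_weight * extra_term


# Random stand-in data with the shape from the question.
pred = torch.rand(64, 4, 128, 128, requires_grad=True)
target = torch.rand(64, 4, 128, 128)

loss = four_channel_loss(pred, target)
loss.backward()  # gradients flow back through the split views to pred
```

torch.split returns views of the original tensor rather than copies, so the split itself should add very little overhead. Plain slicing, pred[:, :3] and pred[:, 3:], selects the same channels if you prefer that style.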
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Rick De |
