How to get a tensor from PyTorch split()
PyTorch's split function returns a tuple of tensors, but I need to batch-matrix-multiply the result. Is there an easy way to split a tensor and get back a single tensor? This is what I tried:
import torch

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)  # tuple of two 4x2 tensors
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)  # fails: matmul expects a Tensor, not a tuple
This gives me an error because m_split is a tuple of tensors rather than a single tensor. Is there a view or reshape call I can use instead?
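On the view/reshape part of the question: when the number of columns divides evenly by the chunk width, a reshape followed by a permute appears to give the same layout as stacking the split chunks, and both operations return views of the original tensor instead of copies. The sketch below is illustrative only; the names `chunk`, `stacked`, `viewed` and `v` are assumptions, not part of the original post.

```python
import torch

m = torch.tensor([[2, 3, 5, 7],
                  [11, 13, 17, 19],
                  [23, 29, 31, 37],
                  [41, 43, 47, 53]])
chunk = 2  # hypothetical chunk width; assumes m.shape[1] % chunk == 0

# Reference: split into column chunks, then stack -> (num_chunks, rows, chunk)
stacked = torch.stack(m.split(chunk, dim=1), dim=0)

# Same layout without a tuple:
# (rows, cols) -> (rows, num_chunks, chunk) -> (num_chunks, rows, chunk)
rows, cols = m.shape
viewed = m.reshape(rows, cols // chunk, chunk).permute(1, 0, 2)

assert torch.equal(stacked, viewed)

v = torch.tensor([[[2, 3, 5, 7]]])  # shape (1, 1, 4)
result = v.matmul(viewed)           # batch dim broadcasts -> shape (2, 1, 2)
print(result.tolist())              # [[[439, 491]], [[545, 627]]]
```

The permute leaves the result non-contiguous, which matmul handles; call .contiguous() only if a later operation requires it.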
Solution 1:
I think you can do the following: split the tensor, then use torch.stack to join the chunks along a new leading (batch) dimension.
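import torch

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
# tensor_split(2) makes 2 equal sections; with 4 columns this matches split(2)
m_split = torch.tensor(m).tensor_split(2, dim=1)
m_split = torch.stack(m_split, dim=0)  # stack accepts the tuple directly; shape (2, 4, 2)
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)  # shape (2, 1, 2)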
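Note that split and tensor_split interpret an integer argument differently: split(2) means chunks of width 2, while tensor_split(2) means 2 sections. They coincide for the 4-column matrix above but not in general. A small sketch on a hypothetical 2x6 tensor `x` (not from the original post) makes the difference visible:

```python
import torch

x = torch.arange(12).reshape(2, 6)  # hypothetical input: 2 rows, 6 columns

# split(2): chunk SIZE 2 -> three chunks of shape (2, 2)
by_size = x.split(2, dim=1)

# tensor_split(2): NUMBER of sections 2 -> two chunks of shape (2, 3)
by_count = x.tensor_split(2, dim=1)

print([tuple(c.shape) for c in by_size])   # [(2, 2), (2, 2), (2, 2)]
print([tuple(c.shape) for c in by_count])  # [(2, 3), (2, 3)]

# torch.stack takes the tuple as-is and adds a new leading dimension
stacked = torch.stack(by_size, dim=0)      # shape (3, 2, 2)
```

torch.stack requires every chunk to have the same shape, so either approach only works directly when the split dimension divides evenly.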
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | noob |
