How to get a tensor from PyTorch split()

PyTorch's split function returns a tuple of tensors, but I need to batch-matrix-multiply the result. Is there an easy way to split a tensor and get back a single tensor? This is what I tried:

import torch

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
m_split = torch.tensor(m).split(2, dim=1)  # tuple of two (4, 2) tensors
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)  # fails: m_split is a tuple, not a Tensor

This gives me an error because m_split is a tuple of tensors rather than a single tensor. Is there a view or reshape call I can use instead?
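A minimal sketch, not taken from the answer below, of a view-based route to the batched layout the question asks about. It assumes the 4-column matrix is cut into 2-column blocks. Calling `view(4, 2, 2)` gives element `[i, k, j] == m[i, 2*k + j]`, and `permute(1, 0, 2)` moves the block index `k` to the front, so batch entry `k` is columns `2k` and `2k + 1`. Both calls return views, so no data is copied. The float32 dtype is an assumption made here because integer matmul may not be supported on every backend.

```python
import torch

m = torch.tensor([[2, 3, 5, 7],
                  [11, 13, 17, 19],
                  [23, 29, 31, 37],
                  [41, 43, 47, 53]], dtype=torch.float32)

# split() returns a tuple of (4, 2) column blocks, not a tensor
chunks = m.split(2, dim=1)

# Option A: stack the tuple into one (2, 4, 2) tensor (this copies data)
batched_stack = torch.stack(chunks, dim=0)

# Option B: pure view, no copy.
# view(4, 2, 2)[i, k, j] == m[i, 2*k + j]; permute to (k, i, j)
batched_view = m.view(4, 2, 2).permute(1, 0, 2)

v = torch.tensor([[[2, 3, 5, 7]]], dtype=torch.float32)  # shape (1, 1, 4)
out = v.matmul(batched_view)  # (1, 1, 4) @ (2, 4, 2) broadcasts to (2, 1, 2)

print(out.shape)                               # torch.Size([2, 1, 2])
print(torch.equal(batched_stack, batched_view))  # True
```

The view only works this way when the column count divides evenly into blocks; for uneven splits, stacking is not possible either, since `torch.stack` needs equal shapes.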



Solution 1:[1]

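I think you can do the following: split the tensor, then use torch.stack to join the resulting tuple into one tensor along a new batch dimension. Note that tensor_split(2, dim=1) splits into 2 sections, while split(2, dim=1) makes chunks of size 2; for a 4-column matrix both produce the same two (4, 2) blocks.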

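import torch

m = [[2, 3, 5, 7],
     [11, 13, 17, 19],
     [23, 29, 31, 37],
     [41, 43, 47, 53]]
# tuple of two (4, 2) column blocks
m_split = torch.tensor(m).tensor_split(2, dim=1)
# stack the tuple into a single (2, 4, 2) tensor (stack accepts a tuple directly)
m_split = torch.stack(m_split, dim=0)
# (1, 1, 4) @ (2, 4, 2) broadcasts to a (2, 1, 2) result
torch.tensor([[[2, 3, 5, 7]]]).matmul(m_split)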
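A hedged sanity check of the stacking approach: it confirms that tensor_split(2) and split(2) give the same blocks for a 4-column matrix, and that batch entry k of the batched product equals the matching 2 columns of the unsplit product v @ m. The float32 dtype is an assumption added here, not part of the answer.

```python
import torch

m = torch.tensor([[2, 3, 5, 7],
                  [11, 13, 17, 19],
                  [23, 29, 31, 37],
                  [41, 43, 47, 53]], dtype=torch.float32)
v = torch.tensor([[[2, 3, 5, 7]]], dtype=torch.float32)  # shape (1, 1, 4)

# tensor_split(2) makes 2 sections; split(2) makes chunks of size 2.
# With 4 columns both yield the same two (4, 2) blocks.
a = m.tensor_split(2, dim=1)
b = m.split(2, dim=1)
same_chunks = all(torch.equal(x, y) for x, y in zip(a, b))

batched = torch.stack(a, dim=0)  # shape (2, 4, 2)
out = v.matmul(batched)          # shape (2, 1, 2)

# Concatenating the batch results should reproduce the unsplit product
full = v.matmul(m)               # shape (1, 1, 4)
matches_full = out.flatten().tolist() == full.flatten().tolist()

print(out.shape, same_chunks, matches_full)
# torch.Size([2, 1, 2]) True True
```

The two approaches only diverge when the dimension size is not a multiple of the chunk count, e.g. tensor_split(2) on 5 columns gives blocks of 3 and 2, which torch.stack would then reject.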

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1, by noob