Loop optimization in PyTorch

I'm trying to find a way to avoid a painfully slow for loop in PyTorch. Basically, I have a tensor that I want to split into pieces and feed piece by piece into my model, similar in spirit to a grouped convolution.

import torch

class Net(torch.nn.Module):
    def __init__(self, C):
        super().__init__()
        self.C = C
        self.block = Block(C, 3, 64)  # Block is a custom module defined elsewhere

    def forward(self, x):
        x = torch.flatten(x, start_dim=1, end_dim=-1).unsqueeze(1)  # (B, 1, C*H*W)
        x = torch.split(x, self.C, -1)  # tuple of (B, 1, C) chunks
        attention = []
        for i in x:  # Python-level loop: one forward pass per chunk
            attended = self.block(i)
            attention.append(attended)
        attention = torch.stack(attention, 1)
        return attention

Small values of C combined with a large input tensor make this operation surprisingly slow, I think because of the Python for loop the code above runs through. When I instead exchange the batch dimension for a 'C' dimension and loop over the batch dimension, I get a significant speedup; however, that still feels hacky to me and might still prove slow with a large enough batch size. I'd like a way to avoid the for loop while keeping the batch dimension intact. What I'm looking for, I guess, is a way to add a second batch dimension to my model, or something equivalent.

Is there any way to fix this issue other than the slightly hacky method described above?

EDIT: MWE (pretend the single linear layer is something like a split attention layer...):

import torch

class Net(torch.nn.Module):
    def __init__(self, split_size):
        super().__init__()
        self.split_size = split_size
        self.linear = torch.nn.Linear(split_size, split_size)
    def forward(self, x):
        # Slow implementation:

        # Input is (B, C, H, W) and is flattened to (B, C*H*W).
        y = x.flatten(start_dim=1, end_dim=-1)

        # The tensor is split along dim 1 and each piece is fed through the layer.
        y_split = torch.split(y, self.split_size, 1)

        outs = []

        for i in y_split:
            i_out = self.linear(i)
            outs.append(i_out)
        y = torch.cat(outs, 1)
        print(y.shape)

        # Fast implementation that loops over the batch dimension instead, but possibly slower for large batches...
        y = x.flatten(start_dim=1, end_dim=-1)
        y_split = torch.split(y, self.split_size, 1)

        y = torch.stack(y_split, 0)  # (num_chunks, B, split_size)
        outs = []
        for i in torch.split(y, 1, 1):  # loop over the batch dimension instead
            i_out = self.linear(i.squeeze(1)).unsqueeze(0)  # (1, num_chunks, split_size)
            outs.append(i_out)
        y = torch.cat(outs, 0)  # (B, num_chunks, split_size)
        y = y.flatten(start_dim=1, end_dim=-1)
        print(y.shape)
        return y
        
if __name__ == "__main__":
    net = Net(32)
    net(torch.randn(256,3,32,32))
    net(torch.randn(32,3,32,32))
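    # Each call prints the (B, C*H*W) shape twice, once per implementation:
    # torch.Size([256, 3072]) for the first call, torch.Size([32, 3072]) for the second.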


Solution 1 [1]

I think a higher-dimensional input will be processed in each channel by the same network layer: nn.Linear applies its weights to the last dimension of the input and broadcasts over all leading dimensions, so you can keep the chunks stacked in an extra dimension and make a single forward call, with no loop.
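Concretely, here is a minimal sketch of a loop-free version of the MWE (not part of the original answer; Net and split_size are borrowed from the question). Since nn.Linear broadcasts over leading dimensions, the chunks can stay stacked in an extra dimension and be processed in a single call:

import torch

class Net(torch.nn.Module):
    def __init__(self, split_size):
        super().__init__()
        self.split_size = split_size
        self.linear = torch.nn.Linear(split_size, split_size)

    def forward(self, x):
        y = x.flatten(start_dim=1)                   # (B, C*H*W)
        y = y.view(y.shape[0], -1, self.split_size)  # (B, num_chunks, split_size)
        y = self.linear(y)                           # one call, broadcast over (B, num_chunks)
        return y.flatten(start_dim=1)                # back to (B, C*H*W)

if __name__ == "__main__":
    net = Net(32)
    print(net(torch.randn(256, 3, 32, 32)).shape)    # torch.Size([256, 3072])

This keeps the batch dimension intact and replaces the Python loop with a single batched matrix multiply.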

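If the per-chunk module did not broadcast over extra leading dimensions the way nn.Linear does, torch.vmap (available in PyTorch 2.x) can map a module over an additional dimension, which is essentially the "second batch dimension" the question asks for. A minimal sketch under that assumption, with a Linear stand-in for the per-chunk block (note that vmap has restrictions, e.g. around in-place ops and randomness):

import torch

block = torch.nn.Linear(32, 32)  # stand-in for the per-chunk module
x = torch.randn(256, 96, 32)     # (B, num_chunks, split_size)

# Treat dim 1 (num_chunks) as a second batch dimension.
out = torch.vmap(block, in_dims=1, out_dims=1)(x)
print(out.shape)                 # torch.Size([256, 96, 32])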
Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

[1] Solution 1 by CSDL, via Stack Overflow.