Loop optimization in PyTorch
I'm trying to find a way to avoid a painfully slow for-loop in PyTorch. Basically, I have a tensor, and I want to split it up into pieces and feed those pieces into my model, similar in spirit to a grouped convolution.
self.C = C
self.block = Block(C, 3, 64)

def forward(self, x):
    x_shape = x.shape
    x = torch.flatten(x, start_dim=1, end_dim=-1).unsqueeze(1)
    x = torch.split(x, self.C, -1)
    attention = []
    for i in x:
        attended = self.block(i)
        attention.append(attended)
    attention = torch.stack(attention, 1)
With small values of C and a large tensor, this operation becomes surprisingly slow, I think because of the Python for-loop the above code runs through. When I instead swap the batch dimension for a 'C' dimension and loop over the batch dimension, I get a significant speedup, but that still feels hacky to me and might still prove slow with a large enough batch size. I'd like a way to fix this while keeping the batch dimension intact and avoiding the for-loop. What I guess I'm looking for is a way to add a second batch dimension to my model, or something equivalent.

Is there any way to fix this issue other than the slightly hacky method described above?
EDIT: MWE (pretend the single linear layer is something like a split attention layer...):
import torch

class Net(torch.nn.Module):
    def __init__(self, split_size):
        super().__init__()
        self.split_size = split_size
        self.linear = torch.nn.Linear(split_size, split_size)

    def forward(self, x):
        # Slow implementation:
        # Input is B,C,H,W and is flattened to B,C*H*W.
        y = x.flatten(start_dim=1, end_dim=-1)
        y_split = torch.split(y, self.split_size, 1)  # Tensor is split and each piece is fed into the model...
        outs = []
        for i in y_split:
            i_out = self.linear(i)
            outs.append(i_out)
        y = torch.cat(outs, 1)
        print(y.shape)

        # Fast implementation using batch dims, but possibly slower for large batches...
        y = x.flatten(start_dim=1, end_dim=-1)
        y_split = torch.split(y, self.split_size, 1)
        y = torch.stack(y_split, 0)
        outs = []
        for i in torch.split(y, 1, 1):
            i_out = self.linear(i.squeeze(1)).unsqueeze(0)
            outs.append(i_out)
        y = torch.cat(outs, 0)
        y = y.flatten(start_dim=1, end_dim=-1)
        print(y.shape)
        return y

if __name__ == "__main__":
    net = Net(32)
    net(torch.randn(256, 3, 32, 32))
    net(torch.randn(32, 3, 32, 32))
Solution 1:[1]
I think you can avoid the loop entirely: torch.nn.Linear is applied to the last dimension of its input, however many leading dimensions there are. So if you reshape the flattened input to (B, num_splits, split_size) and call the layer once, every split is processed by the same layer in parallel, which is exactly what the loop does.
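A minimal sketch of this reshape idea, reusing the shapes from the MWE (the variable names here are illustrative, not from the original post):

```python
import torch

B, split_size = 256, 32
linear = torch.nn.Linear(split_size, split_size)

x = torch.randn(B, 3, 32, 32)
y = x.flatten(start_dim=1)             # (B, C*H*W) = (256, 3072)
num_splits = y.shape[1] // split_size  # 96 pieces of length 32

# Put each piece on the last dimension, then apply the layer once:
y = y.view(B, num_splits, split_size)  # (256, 96, 32)
y = linear(y)                          # same weights applied to every piece, no Python loop
y = y.flatten(start_dim=1)             # back to (256, 3072)
```

This produces the same result as the split-and-loop version (the loop applies identical weights to each chunk), but the whole batch of splits goes through the layer in one call.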
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | CSDL |
