All possible concatenations of two tensors in PyTorch

Suppose I have two tensors S and T defined as:

import torch

S = torch.rand((3,2,1))
T = torch.ones((3,2,1))

We can think of these as containing batches of tensors with shapes (2, 1). In this case, the batch size is 3.

I want to concatenate all possible pairings between the two batches. A single concatenation of one element from S with one element from T produces a tensor of shape (4, 1), and there are 3*3 = 9 such pairings, so ultimately the resulting tensor C must have a shape of (3, 3, 4, 1).

One solution is to do the following:

C = torch.empty((S.shape[0], T.shape[0], 4, 1))  # preallocate the output
for i in range(S.shape[0]):
  for j in range(T.shape[0]):
    C[i,j,:,:] = torch.cat((S[i,:,:],T[j,:,:]))

But the for loop doesn't scale well to large batch sizes. Is there a PyTorch command to do this?



Solution 1:[1]

I don't know of any out-of-the-box command that does such an operation. However, you can pull it off in a straightforward way using a single matrix multiplication.


The trick is to construct a tensor containing all pairs of batch elements, starting from the already-stacked S, T tensor, and then to multiply it by a properly chosen one-hot mask tensor. Keeping track of shapes and dimension sizes is essential throughout. (A runnable end-to-end sketch follows the four steps below.)

  1. The stack is given by (notice the reshape: we essentially flatten the batch elements from S and T into a single batch axis on ST):

    >>> ST = torch.stack((S, T)).reshape(6, 2)
    >>> ST
    tensor([[0.7792, 0.0095],
            [0.1893, 0.8159],
            [0.0680, 0.7194],
            [1.0000, 1.0000],
            [1.0000, 1.0000],
            [1.0000, 1.0000]])
    # ST.shape = (6, 2)
    
  2. You can enumerate the indices of all (S[i], T[j]) pairs in ST using range and itertools.product:

    >>> from itertools import product
    >>> indices = torch.tensor(list(product(range(0, 3), range(3, 6))))
    >>> indices
    tensor([[0, 3],
            [0, 4],
            [0, 5],
            [1, 3],
            [1, 4],
            [1, 5],
            [2, 3],
            [2, 4],
            [2, 5]])
    # indices.shape = (9, 2)
    
  3. From there, we construct one-hot encodings of the indices using torch.nn.functional.one_hot:

    >>> from torch.nn.functional import one_hot
    >>> mask = one_hot(indices).float()
    >>> mask
    tensor([[[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[0., 1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0.]],
    
            [[0., 0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 0., 1.]]])
    # mask.shape = (9, 2, 6)
    
  4. Finally, we compute the matrix multiplication and reshape it to the final form:

    >>> (mask@ST).reshape(3, 3, 4, 1)
    tensor([[[[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]],
    
             [[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]],
    
             [[0.7792],
              [0.0095],
              [1.0000],
              [1.0000]]],
    
    
            [[[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]],
    
             [[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]],
    
             [[0.1893],
              [0.8159],
              [1.0000],
              [1.0000]]],
    
    
            [[[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]],
    
             [[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]],
    
             [[0.0680],
              [0.7194],
              [1.0000],
              [1.0000]]]])
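
Putting the four steps together, here is a minimal end-to-end sketch of the method above, assuming the same toy shapes; the final assert compares the result against the question's double loop. (As a side note, torch.cartesian_prod(torch.arange(0, 3), torch.arange(3, 6)) builds the same (9, 2) index tensor without itertools.)

    from itertools import product

    import torch
    from torch.nn.functional import one_hot

    S = torch.rand((3, 2, 1))
    T = torch.ones((3, 2, 1))

    # Step 1: flatten both batches into a single (6, 2) batch axis.
    ST = torch.stack((S, T)).reshape(6, 2)

    # Step 2: all index pairs; rows 0..2 of ST come from S, rows 3..5 from T.
    indices = torch.tensor(list(product(range(0, 3), range(3, 6))))

    # Step 3: one-hot mask of shape (9, 2, 6); each pair of one-hot rows
    # selects one row of S and one row of T out of ST.
    mask = one_hot(indices, num_classes=ST.shape[0]).float()

    # Step 4: batched matmul (9, 2, 6) @ (6, 2) -> (9, 2, 2), then reshape.
    C = (mask @ ST).reshape(3, 3, 4, 1)

    # Sanity check against the naive double loop from the question.
    C_loop = torch.empty((3, 3, 4, 1))
    for i in range(3):
        for j in range(3):
            C_loop[i, j] = torch.cat((S[i], T[j]))
    assert torch.allclose(C, C_loop)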
    

I initially went with torch.einsum: torch.einsum('bf,pib->pif', ST, mask). But I later realized that bf,pib->pif reduces nicely to a simple torch.Tensor.matmul operation if we switch the two operands, i.e. pib,bf->pif (the subscript b is reduced in the middle).
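
As a quick check of that equivalence, both forms produce the same (9, 2, 2) tensor (a small sketch reusing ST and mask from above):

    out_einsum = torch.einsum('bf,pib->pif', ST, mask)
    out_matmul = mask @ ST  # pib,bf->pif: b is contracted away
    assert torch.allclose(out_einsum, out_matmul)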

Solution 2:[2]

In NumPy, something called np.meshgrid is used for this:

https://stackoverflow.com/a/35608701/3259896

So in PyTorch, it would be

torch.stack(torch.meshgrid(x, y)).T.reshape(-1, 2)

Here x and y are your two 1-D tensors. You can use any number of them: x, y, z, etc.

Then you reshape according to the number of tensors you use: for three tensors use .reshape(-1, 3), for four use .reshape(-1, 4), and so on.

So for 5 tensors, use

torch.stack(torch.meshgrid(a, b, c, d, e)).T.reshape(-1, 5)
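
For illustration, here is a minimal runnable sketch of this approach with two small 1-D tensors. Note that recent PyTorch versions expect an explicit indexing argument for torch.meshgrid (indexing='ij' reproduces the old default behavior), and that Tensor.T on a tensor with more than two dimensions simply reverses all its axes:

    import torch

    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5])

    # Two (3, 2) grids stacked into shape (2, 3, 2); .T reverses the axes
    # so that reshape(-1, 2) yields one (x_i, y_j) pair per row.
    pairs = torch.stack(torch.meshgrid(x, y, indexing='ij')).T.reshape(-1, 2)
    # pairs has shape (6, 2) and contains every (x_i, y_j) combination.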

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1:
Solution 2: SantoshGupta7