Can I replace the "sum" in torch.einsum with a custom function?

I ran into the following problem.

I want to compute a tensor C of the form

    C[a, j, k, h] = sum_i f[a, i, j, k] * g[a, i, j, h]

I know torch.einsum can do this:

    torch.einsum('aijk,aijh->ajkh', f, g)
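To make the contraction explicit: the einsum string above multiplies the two inputs elementwise and sums over the shared index `i`. A small check (with arbitrary example shapes, not from the original post) writing the same thing as a loop over `i`:

```python
import torch

f = torch.randn(2, 3, 4, 5)  # indices (a, i, j, k)
g = torch.randn(2, 3, 4, 6)  # indices (a, i, j, h)

c = torch.einsum('aijk,aijh->ajkh', f, g)  # shape (2, 4, 5, 6)

# the same contraction written as an explicit sum over i
c_loop = torch.zeros(2, 4, 5, 6)
for i in range(3):
    c_loop += f[:, i, :, :, None] * g[:, i, :, None, :]

assert torch.allclose(c, c_loop, atol=1e-5)
```

It is exactly this inner multiply-and-sum over `i` that einsum hard-codes and that the question wants to swap out.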

But I don't want the dot product; I want to replace it with a custom function. Here is a naive solution:


        N, C, H, W = left.shape

        # torch.autograd.Variable is deprecated; allocate the output directly
        cost = torch.empty(N, H, W, W, device=left.device)

        for i in range(H):
            for j in range(W):
                for k in range(W):
                    # compare the (N, C) feature vectors at row i:
                    # column j of `left` against column k of `right`
                    cost[:, i, j, k] = self.qatm_cost(feature1=left[:, :, i, j],
                                                      feature2=right[:, :, i, k])

This shows the idea, but the triple Python loop is very slow. How can I make it faster?
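One common way to remove the loops: if `qatm_cost` (whose definition is not shown in the question) can be written with broadcastable tensor ops that reduce over the channel dimension, you can pair every column `j` of `left` with every column `k` of `right` in a single broadcasted call. A sketch, using squared L2 distance as a hypothetical stand-in for `qatm_cost`:

```python
import torch

def pairwise_cost(left, right, cost_fn):
    """Compute cost_fn(left[:, :, i, j], right[:, :, i, k]) for all (i, j, k).

    left, right: (N, C, H, W). cost_fn must accept broadcastable tensors
    and reduce over dim=1 (the channel dimension).
    """
    a = left.unsqueeze(4)   # (N, C, H, W, 1)
    b = right.unsqueeze(3)  # (N, C, H, 1, W)
    return cost_fn(a, b)    # broadcasts to (N, C, H, W, W), then reduces C

# example custom cost: squared L2 distance (a stand-in, NOT the real qatm_cost)
sq_l2 = lambda a, b: ((a - b) ** 2).sum(dim=1)

left = torch.randn(2, 3, 4, 5)
right = torch.randn(2, 3, 4, 5)
cost = pairwise_cost(left, right, sq_l2)  # shape (N, H, W, W) = (2, 4, 5, 5)
```

This runs one vectorized kernel instead of H*W*W Python-level calls; the trade-off is that the intermediate broadcast tensor is (N, C, H, W, W), so memory grows accordingly. If `qatm_cost` cannot be expressed with broadcastable ops, it cannot be vectorized this way.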



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
