How do I compute batched sample covariance in PyTorch?

Say I have data, a batched tensor of collections of data points of size (B, N, D) where B is my batch size, N is the number of data samples in each collection, and D is the length of my data vectors. I want to compute the sample mean and covariance for each collection of data points, but do it in batch.

To compute the mean I can do data.mean(dim=1) and I get a tensor of size (B, D) representing the mean of each collection. I assumed I'd be able to do something similar with torch.cov, but it does not support batched input. Is there another way to achieve this? I'm expecting to get a batch of covariance matrices of shape (B, D, D).
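For reference, here is a small sketch of what I have so far (the sizes are made up, just to pin down the shapes):

import torch

B, N, D = 4, 100, 3          # hypothetical sizes
data = torch.randn(B, N, D)

means = data.mean(dim=1)     # (B, D): one mean vector per collection
# torch.cov only handles one 2-D collection at a time and expects variables
# as rows, so each collection has to be transposed to (D, N) first:
cov_0 = torch.cov(data[0].T) # (D, D) covariance of the first collection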



Solution 1:[1]

This does the trick:

import torch

def batch_cov(points):
    B, N, D = points.size()
    mean = points.mean(dim=1).unsqueeze(1)     # (B, 1, D)
    diffs = (points - mean).reshape(B * N, D)  # centered points, flattened across the batch
    # Outer product of each centered point with itself: (B*N, D, 1) @ (B*N, 1, D)
    prods = torch.bmm(diffs.unsqueeze(2), diffs.unsqueeze(1)).reshape(B, N, D, D)
    bcov = prods.sum(dim=1) / (N - 1)  # Unbiased estimate
    return bcov  # (B, D, D)

Here is a script to check that it computes the same result as the non-batched PyTorch version:

import time
import torch

B = 10000
N = 50
D = 2
points = torch.randn(B, N, D)
start = time.time()
my_covs = batch_cov(points)
print("My time:   ", time.time() - start)

start = time.time()
torch_covs = torch.zeros_like(my_covs)
for i, batch in enumerate(points):
    torch_covs[i] = batch.T.cov()  # torch.cov expects variables as rows, i.e. a (D, N) tensor

print("Torch time:", time.time() - start)
print("Same?", torch.allclose(my_covs, torch_covs, atol=1e-7))

Which gives me:

My time:    0.00251793861318916016
Torch time: 0.2459864616394043
Same? True

I can't claim mine will be inherently faster than iteratively computing them; it seems that as D gets bigger, mine will slow down much more, so there's probably a nicer way to scale with the data dimension size.
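One candidate (an untested sketch, not part of the original answer; batch_cov_matmul is just an illustrative name) is to fold the sum of outer products into a single batched matrix product, so the (B, N, D, D) intermediate is never materialized:

import torch

def batch_cov_matmul(points):
    # points: (B, N, D)
    B, N, D = points.size()
    diffs = points - points.mean(dim=1, keepdim=True)  # center each collection: (B, N, D)
    # (B, D, N) @ (B, N, D) sums the outer products over N directly,
    # without ever building a (B, N, D, D) tensor
    return diffs.transpose(1, 2) @ diffs / (N - 1)     # (B, D, D)

Up to floating-point error this should give the same (B, D, D) result, while the largest intermediate stays at (B, N, D) instead of (B, N, D, D).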

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution source
[1] adamconkey