How to scale a fixed sparse matrix by the value in a 1x1 tensor in PyTorch?

Is it possible to scale a fixed sparse matrix by the value in a 1x1 tensor in pytorch?

For example, in code I'm working on I'm seeing the following issue:

>>> import torch
>>> sp_mat = torch.sparse_coo_tensor([[0,1,2],[0,1,2]],[1,1,1],(3,3))
>>> w = torch.tensor([0.5])
>>> sp_mat*w
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: mul operands have incompatible sizes

Is there a workaround? Ultimately I want to let the w variable be a learnable parameter, but cannot seem to find a way to get this operation to work when w is a tensor.

It works just fine if the weight is a float:

>>> import torch
>>> sp_mat = torch.sparse_coo_tensor([[0,1,2],[0,1,2]],[1,1,1],(3,3))
>>> y = 0.5
>>> sp_mat*y
tensor(indices=tensor([[0, 1, 2],
                       [0, 1, 2]]),
       values=tensor([0.5000, 0.5000, 0.5000]),
       size=(3, 3), nnz=3, layout=torch.sparse_coo)

Any suggestions? Thanks!



Solution 1:[1]

While an unlearnable parameter is simple, making the tensor learnable requires it to have the same shape as your data (which unfortunately doubles the memory requirement, since a 0-D tensor, whether dense or sparse, does not seem to broadcast correctly against a sparse matrix).

In this case w has to be recreated as a sparse tensor, which could be done like this (sp_mat is the same as t below):

w = torch.sparse_coo_tensor(
  t.indices(), 
  torch.full_like(
    t.values(), 
    0.5,
  ),
  t.shape, 
  requires_grad=True, 
)
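For reference, a minimal end-to-end sketch of the above. Note that t must be coalesced before .indices() and .values() can be called on it, and that elementwise multiplication of two sparse COO tensors with matching sparsity patterns is supported:

```python
import torch

# The fixed sparse matrix (coalesced so .indices()/.values() are valid);
# float values so full_like produces float weights.
t = torch.sparse_coo_tensor([[0, 1, 2], [0, 1, 2]], [1., 1., 1.], (3, 3)).coalesce()

# Learnable sparse weight with the same sparsity pattern as t.
w = torch.sparse_coo_tensor(
    t.indices(),
    torch.full_like(t.values(), 0.5),
    t.shape,
    requires_grad=True,
)

result = t * w  # elementwise sparse-sparse multiplication
print(result.to_dense())
```
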

Thanks also to Phoenix for pointing out my misreading of the original question.

The downside is that every value in the sparse matrix now has its own independent weight, which might not be what you are after. AFAIK there is no direct way to combine a "learnable scalar" with a sparse matrix, so further hacks would be needed (please correct me if I'm wrong).
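One such hack, as a minimal sketch: rebuild the sparse tensor from its indices and scaled values on each forward pass, so the scalar multiplication happens on the dense values tensor (where broadcasting works) and the gradient flows back to the scalar w, since torch.sparse_coo_tensor is differentiable with respect to its values argument:

```python
import torch

sp_mat = torch.sparse_coo_tensor([[0, 1, 2], [0, 1, 2]], [1., 1., 1.], (3, 3)).coalesce()
w = torch.tensor(0.5, requires_grad=True)  # learnable scalar

# Scale the dense values tensor (ordinary dense * scalar broadcasting),
# then wrap the result back into a sparse tensor with the same indices.
scaled = torch.sparse_coo_tensor(sp_mat.indices(), sp_mat.values() * w, sp_mat.shape)

scaled.to_dense().sum().backward()
print(w.grad)  # gradient w.r.t. the scalar: the sum of the original values
```
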

EDIT: Just tested another approach, namely:

w = torch.tensor(0.5, requires_grad=True).to_sparse()

Multiplying with it also fails, which looks like a bug. You might want to open a PyTorch issue about it.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
