How to apply the histogram function in PyTorch to a specific axis?
I would like to apply the torch.histc function to each sample in my training batch separately.
Here is an example:
>>> tt2 = torch.from_numpy(np.array([[-0.2, 1, 0.21], [-0.1, 0.32, 0.2]]))
>>> tt3 = torch.from_numpy(np.array([[-0.8, 0.6, 0.1], [-0.6, 0.5, 0.4]]))
>>> t = torch.cat((tt2, tt3), 1)
>>> t
tensor([[-0.2000,  1.0000,  0.2100, -0.8000,  0.6000,  0.1000],
        [-0.1000,  0.3200,  0.2000, -0.6000,  0.5000,  0.4000]],
       dtype=torch.float64)
>>> torch.histc(t, bins=1, min=0, max=5)
tensor([8.], dtype=torch.float64)
However, I don't want to apply the histogram function to all the values in t at once. Instead, I expect something like this, one call per row:
>>> torch.histc(torch.tensor([[-0.2000, 1.0000, 0.2100, -0.8000, 0.6000, 0.1000]]), bins=1, min=0, max=5)
tensor([4.])
>>> torch.histc(torch.tensor([[-0.1000, 0.3200, 0.2000, -0.6000, 0.5000, 0.4000]]), bins=1, min=0, max=5)
tensor([4.])
And, finally, I want to aggregate all the histograms in the same tensor: tensor([[4.], [4.]]).
Thanks in advance
Solution 1:[1]
You can try this:
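import torch

torch.manual_seed(1)
bins = 1
t = torch.rand((2, 6))  # example batch: 2 samples, 6 values each

# Split the batch into one (1, 6) tensor per row
tuple_rows = torch.tensor_split(t, t.shape[0], dim=0)
final_tensor = torch.empty((t.shape[0], bins))
for i, row in enumerate(tuple_rows):
    # torch.histc counts over all elements of its input, so call it once per row
    final_tensor[i] = torch.histc(row, bins=bins, min=0, max=5)

# Result: final_tensor is tensor([[6.], [6.]])
# (torch.rand draws from [0, 1), so all 6 values of each row fall in [0, 5])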
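The loop above calls torch.histc once per row, which can get slow for large batches. As a possible alternative, here is a minimal loop-free sketch that computes the bin index of every value and accumulates counts with scatter_add_. The helper name histc_per_row and its parameters lo and hi are hypothetical (they are not part of PyTorch), and the binning is an assumption meant to mimic torch.histc: equal-width bins over [lo, hi], values outside the range ignored, and a value equal to hi counted in the last bin. Values that sit exactly on a bin edge may land differently than with torch.histc because of floating-point rounding.

```python
import torch


def histc_per_row(x, bins, lo, hi):
    """Hypothetical helper: one equal-width histogram per row of a 2-D tensor."""
    # Values outside [lo, hi] contribute nothing, as with torch.histc
    in_range = (x >= lo) & (x <= hi)
    # Map each value to a bin index; clamp so that x == hi lands in the last bin
    # and out-of-range values get a valid (but zero-weighted) index
    idx = ((x - lo) / (hi - lo) * bins).floor().long().clamp(0, bins - 1)
    hist = torch.zeros(x.shape[0], bins, dtype=x.dtype, device=x.device)
    # For every row i and column j: hist[i, idx[i, j]] += 1 if x[i, j] is in range
    hist.scatter_add_(1, idx, in_range.to(x.dtype))
    return hist


# The batch from the question
t = torch.tensor([[-0.2, 1.0, 0.21, -0.8, 0.6, 0.1],
                  [-0.1, 0.32, 0.2, -0.6, 0.5, 0.4]])
result = histc_per_row(t, bins=1, lo=0.0, hi=5.0)
print(result)  # each row has 4 values in [0, 5], so both counts are 4
```

With bins=1 this gives the tensor([[4.], [4.]]) asked for in the question. If exact agreement with torch.histc matters, it is worth comparing both versions on your own data before swapping out the loop.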
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 |