Can you get the max over segments of a tensor with different lengths using PyTorch?

Given a PyTorch tensor t of length, say, 15, is there a 'nice' way of getting the maximum of each of several disjoint, consecutive subsets of this tensor? Specifically, given a list l = [5, 4, 6], I want the max of the first 5 elements of t, then the max of the next 4 elements, and then the max of the final 6.

If the elements of l are all equal, the tensor can be reshaped into a matrix with one row per subset, and the max of each row can be found in a single call. When the elements of l differ, I can't see a nice way of doing this without resorting to a Python loop. It would be nice to find a vectorised (parallelisable) way of doing it.
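One way to avoid the Python loop is sketched below, assuming PyTorch 1.12 or later (where `Tensor.scatter_reduce` accepts `reduce="amax"` and the `include_self` keyword) and assuming the lengths sum to the number of elements. Each element is labelled with the index of its subset via `torch.repeat_interleave`, and a single scatter reduction then takes the max per label. The function name `segment_max` is hypothetical, not a PyTorch API.

```python
import torch
from typing import Sequence


def segment_max(values: torch.Tensor, lengths: Sequence[int]) -> torch.Tensor:
    """Max of each consecutive segment of `values`; segment i has lengths[i] elements."""
    # Flatten, e.g. a (15, 1) column of Q-values becomes shape (15,)
    values = values.reshape(-1)
    # -inf is used as the starting value below, so work in floating point
    if not values.is_floating_point():
        values = values.float()
    num_segments = len(lengths)
    lengths_t = torch.as_tensor(lengths, dtype=torch.long, device=values.device)
    # Segment id of every element, e.g. lengths [2, 3] -> [0, 0, 1, 1, 1]
    segment_ids = torch.repeat_interleave(
        torch.arange(num_segments, device=values.device), lengths_t
    )
    # Segments that receive no elements keep the -inf start value
    out = torch.full((num_segments,), float("-inf"), dtype=values.dtype, device=values.device)
    # include_self=False: the -inf start values do not take part in the max
    return out.scatter_reduce(0, segment_ids, values, reduce="amax", include_self=False)


q_values_tensor = torch.randint(10, (15, 1))
print(segment_max(q_values_tensor, [5, 4, 6]))  # three maxima, one per subset
```

The reduction runs as one batched operation regardless of how many subsets there are, and an empty subset simply yields -inf rather than raising an error.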

For context, I'm working on a Reinforcement Learning problem in which my states are graphs. An action is just a vertex of the graph. I'm using DQN. When I sample a batch from the buffer, I generate a Q value for each vertex of each graph in the batch, and for each graph I want the maximum of these Q values. But when the graphs have different numbers of vertices, I run into the problem described above: I can't see a way of getting the max Q value for each graph without looping through the batch.
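For the batched-graph case, a padding-and-mask variant may be convenient, since it returns the greedy vertex as well as its Q value (for instance for a DQN target of the form r + gamma * max_a Q(s', a)). This is a sketch under stated assumptions: the Q values of all graphs are concatenated graph by graph in batch order, every graph has at least one vertex, and `num_vertices` lives on the same device as the Q values. The names `max_q_per_graph` and `num_vertices` are hypothetical.

```python
import torch


def max_q_per_graph(q_values: torch.Tensor, num_vertices: torch.Tensor):
    """Max Q-value and best vertex for each graph in a batch.

    q_values:     1-D float tensor, Q-values of all vertices of all graphs,
                  concatenated graph by graph in batch order.
    num_vertices: 1-D long tensor on the same device, vertices per graph (each > 0).
    Returns (max_q, best_vertex), both of shape (batch_size,); best_vertex is
    the vertex index within its own graph.
    """
    batch_size = num_vertices.numel()
    max_n = int(num_vertices.max())
    # mask[i, j] is True when graph i has a vertex j
    mask = torch.arange(max_n, device=q_values.device) < num_vertices.unsqueeze(1)
    # Padded (batch_size, max_n) grid; padding stays at -inf so it never wins the max
    padded = torch.full(
        (batch_size, max_n), float("-inf"), dtype=q_values.dtype, device=q_values.device
    )
    # Boolean-mask assignment fills True slots in row-major order, i.e. graph by graph
    padded[mask] = q_values
    max_q, best_vertex = padded.max(dim=1)
    return max_q, best_vertex


q = torch.tensor([0.1, 0.7, 0.3, 0.9, 0.2, 0.5, 0.4, 0.8, 0.6])
max_q, best_vertex = max_q_per_graph(q, torch.tensor([3, 2, 4]))
```

The trade-off is memory: the padded grid has batch_size * max_n entries, which is wasteful if one graph is much larger than the rest.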

Here is an example of how to do this using a loop:

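# Case when the subset lengths are unbalanced
import torch

q_values_tensor = torch.randint(10, (15, 1))  # 15 random integer Q-values in [0, 10)
lengths_list = [5, 4, 6]                       # sizes of the consecutive subsets
max_q_values_unbalanced = torch.zeros((1, len(lengths_list)))
tensor_index = 0  # start index of the current subset
for list_index, length in enumerate(lengths_list):
    # Max over elements tensor_index .. tensor_index + length - 1
    max_q_values_unbalanced[0, list_index] = torch.max(q_values_tensor[tensor_index:tensor_index + length])
    tensor_index += length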
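A loop-free variant of the loop above, offered as a sketch rather than a canonical answer: `torch.split` cuts the flattened tensor into the requested pieces, and `torch.nn.utils.rnn.pad_sequence` (normally used to batch RNN inputs, repurposed here) pads them to a common length with -inf, after which the balanced-case trick of a row-wise max applies. The variable name `max_q_values_padded` is hypothetical; the result is converted to float because -inf cannot be stored in an integer tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

q_values_tensor = torch.randint(10, (15, 1))
lengths_list = [5, 4, 6]

# Split the flattened (float) tensor into pieces of the requested lengths
segments = list(torch.split(q_values_tensor.reshape(-1).float(), lengths_list))
# Stack into shape (len(lengths_list), max(lengths_list)); short rows are padded with -inf
padded = pad_sequence(segments, batch_first=True, padding_value=float("-inf"))
# One max per row, reshaped to (1, len(lengths_list)) to match max_q_values_unbalanced
max_q_values_padded = padded.max(dim=1).values.unsqueeze(0)
```

Only the splitting and padding are done per subset; the max itself is a single batched call, exactly as in the balanced case.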

And an example where the subset lengths are all equal (I'm hoping there might be a solution that is similar to this):

# Case when the subset lengths are balanced (what I'd like to replicate)
import torch

q_values_tensor = torch.randint(10, (15, 1))
lengths_list = [5, 5, 5]
# One row per subset: shape (3, 5)
reshaped_q_values_tensor = q_values_tensor.view(-1, lengths_list[0])
# Row-wise max; .max(1) returns (values, indices), so [0] keeps the values
max_q_values_balanced = reshaped_q_values_tensor.max(1)[0]


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
