Pytorch: Assign values from one mask to another, masked by itself

I have a mask active that tracks batches that have not yet terminated in a recurrent process. Its shape is [batch_full,], and its True entries mark which elements still need to be processed in the current step. The recurrent process generates another mask, terminated, which has as many elements as there are True values in the active mask. Now I want to take the values from ~terminated and put them back into active, at the right indices. Basically I want to do:

import torch

active = torch.ones([4,], dtype=torch.bool)
active[:2] = False  # first two batch elements already terminated

terminated = torch.tensor([True, False])  # one entry per True in active

active[active] = ~terminated  # raises the RuntimeError below

print(active)  # expected [False, False, False, True]

However, I get this error:

RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.

How can I perform the operation described above efficiently?



Solution 1:

There are a few solutions; I measured the speed of each with timeit (10k repetitions) on a 2021 MacBook Pro. The error occurs because active serves both as the index mask and as the write target, so the read and the write refer to the same memory.

The simplest solution, taking 0.260s, is to clone the index mask:

active[active.clone()] = ~terminated

We can use the in-place masked_scatter_ operation for about a 2x speedup (0.136s):

active.masked_scatter_(active, ~terminated)
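To clarify the semantics masked_scatter_ relies on here: source values are consumed left to right and written into the True positions of the mask. A small illustration, using data unrelated to the question:

```python
import torch

# masked_scatter_ copies source elements, in order, into the
# positions of the target where the mask is True.
t = torch.zeros(5, dtype=torch.long)
mask = torch.tensor([True, False, True, False, True])
src = torch.tensor([10, 20, 30])
t.masked_scatter_(mask, src)
print(t.tolist())  # [10, 0, 20, 0, 30]
```

In the solution above the mask happens to be active itself, which matches the boolean-index assignment without triggering the aliasing complaint.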

The out-of-place variant, taking 0.161s, would be:

active = torch.masked_scatter(active, active, ~terminated)
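As a sanity check (not part of the original answer), all three variants can be verified to produce the expected [False, False, False, True] on the question's example data:

```python
import torch

active_init = torch.tensor([False, False, True, True])
terminated = torch.tensor([True, False])
expected = torch.tensor([False, False, False, True])

# 1. Boolean-index assignment with a cloned index mask
a = active_init.clone()
a[a.clone()] = ~terminated

# 2. In-place masked_scatter_
b = active_init.clone()
b.masked_scatter_(b, ~terminated)

# 3. Out-of-place torch.masked_scatter
c = torch.masked_scatter(active_init, active_init, ~terminated)

assert torch.equal(a, expected)
assert torch.equal(b, expected)
assert torch.equal(c, expected)
```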

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
