What does Tensor[batch_mask, ...] do?
I saw this line of code in an implementation of a BiLSTM:
batch_output = batch_output[batch_mask, ...]
I assume this is some kind of "masking" operation, but I found little information on Google about the meaning of .... Please help :).
Original Code:
```python
class BiLSTM(nn.Module):
    def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
                 num_layers, bidirectional, dropout, pretrained=None):
        # irrelevant code ..........

    def forward(self, batch_input, batch_input_lens, batch_mask):
        batch_size, padding_length = batch_input.size()
        batch_input = self.word_embeds(batch_input)  # size: #batch * padding_length * embedding_dim
        batch_input = rnn_utils.pack_padded_sequence(
            batch_input, batch_input_lens, batch_first=True)
        batch_output, self.hidden = self.lstm(batch_input, self.hidden)
        self.repackage_hidden(self.hidden)
        batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
        batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
        ####### HERE ##########
        batch_output = batch_output[batch_mask, ...]
        #########################
        out = self.hidden2tag(batch_output)
        return out
```
Solution 1:[1]
I assume that batch_mask is a boolean tensor. In that case, batch_output[batch_mask] performs boolean indexing, which selects the elements of batch_output whose corresponding entry in batch_mask is True.
... is usually referred to as an ellipsis, and in PyTorch (as in other NumPy-like libraries) it is shorthand that avoids writing the colon slicing operator (:) once per remaining dimension. For example, given a tensor v with v.shape equal to (2, 3, 4), the expression v[1, :, :] can be rewritten as v[1, ...].
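A quick way to verify that equivalence (a minimal sketch; the tensor here is made up for illustration):

```python
import torch

v = torch.arange(24).reshape(2, 3, 4)

# `...` expands to as many full slices (:) as needed to cover
# the remaining dimensions, so both expressions are equivalent.
print(torch.equal(v[1, :, :], v[1, ...]))  # True
print(v[1, ...].shape)                     # torch.Size([3, 4])
```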
I performed some tests, and batch_output[batch_mask, ...] and batch_output[batch_mask] appear to behave identically:
```python
import torch

t = torch.arange(24).reshape(2, 3, 4)
# mask.shape == (2, 3), matching the first two dimensions of t
mask = torch.tensor([[False, True, True], [True, False, False]])
print(torch.all(t[mask] == t[mask, ...]))  # prints tensor(True)
```
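In the BiLSTM above, the same mechanism is applied after the output has been flattened to shape (batch_size * padding_length, hidden). Here is a minimal sketch of that step, with made-up sizes and an illustrative mask (none of the values below come from the original code):

```python
import torch

batch_size, padding_length, hidden = 2, 3, 4
batch_output = torch.randn(batch_size * padding_length, hidden)

# One boolean per flattened (sequence, timestep) position:
# True for real tokens, False for padding.
batch_mask = torch.tensor([True, True, False,    # sequence 1 has length 2
                           True, False, False])  # sequence 2 has length 1

kept = batch_output[batch_mask, ...]
print(kept.shape)  # torch.Size([3, 4]) -- only the real tokens survive
```

Presumably this is why the masking happens before hidden2tag: the classifier then sees only outputs that correspond to actual tokens, not to padding.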
Solution 2:[2]
This statement indexes the first dimension of batch_output with the indices contained in batch_mask. In practice, this means you are selecting some of the elements from the batch.
Here is a practical example:
```python
>>> x = torch.rand(3, 1, 4, 4)
>>> x
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
          [0.7685, 0.5583, 0.2817, 0.9678],
          [0.8878, 0.9477, 0.2554, 0.8261],
          [0.2708, 0.3403, 0.7734, 0.2584]]],

        [[[0.5471, 0.5031, 0.3906, 0.7554],
          [0.1895, 0.3985, 0.7083, 0.7849],
          [0.3128, 0.6733, 0.9223, 0.5345],
          [0.2689, 0.9876, 0.1092, 0.7405]]],

        [[[0.9834, 0.0276, 0.7114, 0.2872],
          [0.3483, 0.2104, 0.1816, 0.5615],
          [0.4323, 0.5329, 0.9198, 0.8647],
          [0.9054, 0.5763, 0.7939, 0.8388]]]])
```
Now define a mask and apply the masking operation:
```python
>>> mask = [0, 2]
>>> x[mask]
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
          [0.7685, 0.5583, 0.2817, 0.9678],
          [0.8878, 0.9477, 0.2554, 0.8261],
          [0.2708, 0.3403, 0.7734, 0.2584]]],

        [[[0.9834, 0.0276, 0.7114, 0.2872],
          [0.3483, 0.2104, 0.1816, 0.5615],
          [0.4323, 0.5329, 0.9198, 0.8647],
          [0.9054, 0.5763, 0.7939, 0.8388]]]])
```
Only the elements at indices 0 and 2 remain.
Note: x[mask] is identical to x[mask, ...]; the ellipsis is not necessary here, since any trailing dimensions that are not explicitly indexed have all of their indices selected by default.
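Since Solution 1 uses a boolean mask and Solution 2 uses integer indices, it may help to see that the two forms can express the same selection. A small illustrative sketch, using torch.nonzero to convert a boolean mask into indices:

```python
import torch

x = torch.rand(3, 1, 4, 4)

bool_mask = torch.tensor([True, False, True])   # keep rows 0 and 2
int_mask = torch.nonzero(bool_mask).squeeze(1)  # tensor([0, 2])

print(torch.equal(x[bool_mask], x[int_mask]))   # True
```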
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | aretor |
| Solution 2 | Ivan |
