Why do we slice tgt_input before passing it into the transformer?
I'm going through PyTorch's "Language Translation with nn.Transformer and torchtext" tutorial, and I notice that the target batch tgt gets sliced before being fed into the model. What purpose does this serve? I know that the target sequence needs to be masked, but from the code it looks like the masking is already handled by create_mask, so the slicing doesn't seem necessary.
def train_epoch(model, optimizer):
    model.train()
    losses = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]  # target sequence without its last token

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask,
                       src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]  # target sequence without its first token
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)
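To make my question concrete, here is how I understand the two slices lining up, using a toy target sequence of made-up token IDs (2 and 3 standing in for BOS/EOS; this is just my illustration, not code from the tutorial):

import torch

# one target sentence: <bos> w1 w2 w3 <eos>, shaped (seq_len, batch=1) as in the tutorial
tgt = torch.tensor([[2], [11], [12], [13], [3]])

tgt_input = tgt[:-1, :]  # [2, 11, 12, 13] -> fed to the decoder
tgt_out   = tgt[1:, :]   # [11, 12, 13, 3] -> compared against the logits in the loss

print(tgt_input.squeeze(1).tolist())  # [2, 11, 12, 13]
print(tgt_out.squeeze(1).tolist())    # [11, 12, 13, 3]

So the decoder input and the loss target are the same sequence shifted by one position, and I'm asking why this shift is needed on top of the masking.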
Related documentation (torch.nn.Transformer.forward): https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#Transformer.forward