'How to make an LSTM Bidirectional?
Question:
What changes to LSTMClassifier do I need to make, in order to have this LSTM work bidirectionally?
I'm basing my amendments on this disscuss.pytorch.org response.
I think the problem is in forward(). It learns from the last state of LSTM neural network, by slicing:
tag_space = self.classifier(lstm_out[:,-1,:])
Do I need to sum up or concatenate the values of the 2 layers/ directions?
Working Code:
from argparse import ArgumentParser
import torchmetrics
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMClassifier(nn.Module):
def __init__(self,
num_classes,
batch_size=10,
embedding_dim=100,
hidden_dim=50,
vocab_size=128):
super(LSTMClassifier, self).__init__()
initrange = 0.1
self.num_labels = num_classes
n = len(self.num_labels)
self.hidden_dim = hidden_dim
self.batch_size = batch_size
self.num_layers = 1
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
self.word_embeddings.weight.data.uniform_(-initrange, initrange)
self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=self.num_layers, batch_first=True, bidirectional=True) # !
#self.classifier = nn.Linear(hidden_dim, self.num_labels[0])
self.classifier = nn.Linear(2 * hidden_dim, self.num_labels[0]) # !
def repackage_hidden(h):
"""Wraps hidden states in new Tensors, to detach them from their history."""
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
def forward(self, sentence, labels=None):
embeds = self.word_embeddings(sentence)
# lstm_out, _ = self.lstm(embeds) # lstm_out - 2 tensors, _ - hidden layer
lstm_out, hidden = self.lstm(embeds)
# Calculate number of directions
self.num_directions = 2 if self.lstm.bidirectional == True else 1
# Extract last hidden state
# final_state = hidden.view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
final_state = hidden[0].view(self.num_layers, self.num_directions, self.batch_size, self.hidden_dim)[-1]
# Handle directions
final_hidden_state = None
if self.num_directions == 1:
final_hidden_state = final_state.squeeze(0)
elif self.num_directions == 2:
h_1, h_2 = final_state[0], final_state[1]
# final_hidden_state = h_1 + h_2 # Add both states (requires changes to the input size of first linear layer + attention layer)
final_hidden_state = torch.cat((h_1, h_2), 1) # Concatenate both states
print("len(final_hidden_state)", len(final_hidden_state))
print("len(labels)", len(labels))
# tag_space = self.classifier(hidden[:,0,:] + hidden[:,-1,:]) # ! # torch.flip(lstm_out[:,-1,:], [0, 1]) - 1 tensor
logits = F.log_softmax(final_hidden_state, dim=1) # tag_space
loss = None
if labels:
loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
return loss, logits
class LSTMTaggerModel(pl.LightningModule):
def __init__(
self,
num_classes,
class_map,
from_checkpoint=False,
model_name='last.ckpt',
learning_rate=3e-6,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
self.learning_rate = learning_rate
self.model = LSTMClassifier(num_classes=num_classes)
# self.model.load_state_dict(torch.load(model_name), strict=False) # !
self.class_map = class_map
self.num_classes = num_classes
self.valid_acc = torchmetrics.Accuracy()
self.valid_f1 = torchmetrics.F1()
def forward(self, *input, **kwargs):
return self.model(*input, **kwargs)
def training_step(self, batch, batch_idx):
x, y_true = batch
loss, _ = self(x, labels=y_true)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y_true = batch
_, y_pred = self(x, labels=y_true)
preds = torch.argmax(y_pred, axis=1)
self.valid_acc(preds, y_true[0])
self.log('val_acc', self.valid_acc, prog_bar=True)
self.valid_f1(preds, y_true[0])
self.log('f1', self.valid_f1, prog_bar=True)
def configure_optimizers(self):
'Prepare optimizer and schedule (linear warmup and decay)'
opt = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
return [opt], [sch]
def training_epoch_end(self, training_step_outputs):
avg_loss = torch.tensor([x['loss']
for x in training_step_outputs]).mean()
self.log('train_loss', avg_loss)
print(f'###score: train_loss### {avg_loss}')
def validation_epoch_end(self, val_step_outputs):
acc = self.valid_acc.compute()
f1 = self.valid_f1.compute()
self.log('val_score', acc)
self.log('f1', f1)
print(f'###score: val_score### {acc}')
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("OntologyTaggerModel")
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--learning_rate", default=2e-3, type=float)
return parent_parser
Runtime:
Global seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
| Name | Type | Params
---------------------------------------------
0 | model | LSTMClassifier | 77.4 K
1 | valid_acc | Accuracy | 0
2 | valid_f1 | F1 | 0
---------------------------------------------
77.4 K Trainable params
0 Non-trainable params
77.4 K Total params
0.310 Total estimated model params size (MB)
Validation sanity check: 0it [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-18-3f817f701f20> in <module>
11 """.split()
12
---> 13 run_training(args)
<ipython-input-5-bb0d8b014e32> in run_training(input)
66 shutil.copyfile(labels_file_orig, labels_file_cp)
67 trainer = pl.Trainer.from_argparse_args(args, callbacks=[checkpoint_callback], logger=loggers)
---> 68 trainer.fit(model, dm)
69 model_file = os.path.join(args.modeldir, 'last.ckpt')
70 trainer.save_checkpoint(model_file, weights_only=True)
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
497
498 # dispath `start_training` or `start_testing` or `start_predicting`
--> 499 self.dispatch()
500
501 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
544
545 else:
--> 546 self.accelerator.start_training(self)
547
548 def train_or_test_or_predict(self):
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
71
72 def start_training(self, trainer):
---> 73 self.training_type_plugin.start_training(trainer)
74
75 def start_testing(self, trainer):
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
112 def start_training(self, trainer: 'Trainer') -> None:
113 # double dispatch to initiate the training loop
--> 114 self._results = trainer.run_train()
115
116 def start_testing(self, trainer: 'Trainer') -> None:
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
605 self.progress_bar_callback.disable()
606
--> 607 self.run_sanity_check(self.lightning_module)
608
609 # set stage for logging
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
858
859 # run eval step
--> 860 _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)
861
862 self.on_sanity_check_end()
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, max_batches, on_epoch)
723 # lightning module methods
724 with self.profiler.profile("evaluation_step_and_end"):
--> 725 output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
726 output = self.evaluation_loop.evaluation_step_end(output)
727
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/trainer/evaluation_loop.py in evaluation_step(self, batch, batch_idx, dataloader_idx)
164 model_ref._current_fx_name = "validation_step"
165 with self.trainer.profiler.profile("validation_step"):
--> 166 output = self.trainer.accelerator.validation_step(args)
167
168 # capture any logged information
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/accelerators/accelerator.py in validation_step(self, args)
175
176 with self.precision_plugin.val_step_context(), self.training_type_plugin.val_step_context():
--> 177 return self.training_type_plugin.validation_step(*args)
178
179 def test_step(self, args):
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in validation_step(self, *args, **kwargs)
129
130 def validation_step(self, *args, **kwargs):
--> 131 return self.lightning_module.validation_step(*args, **kwargs)
132
133 def test_step(self, *args, **kwargs):
<ipython-input-17-542f29e75b1a> in validation_step(self, batch, batch_idx)
104 def validation_step(self, batch, batch_idx):
105 x, y_true = batch
--> 106 _, y_pred = self(x, labels=y_true)
107 preds = torch.argmax(y_pred, axis=1)
108 self.valid_acc(preds, y_true[0])
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-17-542f29e75b1a> in forward(self, *input, **kwargs)
94
95 def forward(self, *input, **kwargs):
---> 96 return self.model(*input, **kwargs)
97
98 def training_step(self, batch, batch_idx):
~/anaconda3/envs/pytorch_latest_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
<ipython-input-17-542f29e75b1a> in forward(self, sentence, labels)
67 loss = None
68 if labels:
---> 69 loss = F.cross_entropy(logits.view(-1, self.num_labels[0]), labels[0].view(-1))
70 return loss, logits
71
RuntimeError: shape '[-1, 38]' is invalid for input of size 1000
Solution 1:[1]
It sounds like you're trying to load a pretrained model (which uses an unidirectional LSTM) into a model which has a bidirectional LSTM in its state dict. There are several things you can do here, as there are innate differences between your pretrained state dict and your bidirectional state dict:
- Definitely use
model.load_state_dict(model_params,strict=False)(see this link). This will stop the complaining when you use a model that's different to what you're trying to learn. It means that your forward pass will be pretrained but not your backward pass. - If you do this ^ you will need to sum or otherwise condense the final time steps for the forward and backward case because the classifier will then have a different shape otherwise.
strict=Falsethough will ignore this, so only do this if you care about having a pretrained first layer in your classifier. - If you don't want to do the above two, you can copy the weights for
model.lstm.weight_ih_l0_reverseand other missing parameters from the forward direction in the state dict, as it's just a python dictionary. It is not ideal because obviously the forward and backward pass will learn different things, but will stop the error and be in a reasonably good initialisation space. You will still have the same error in two though where your LSTM output is twice as big as it was.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jhso |
