PyTorch Lightning model checkpoint loading for Barlow Twins tutorial
I am working through the Barlow Twins tutorial with PyTorch Lightning, and I am having trouble loading the encoder portion of the model from a saved checkpoint after training.
During model training, checkpoints are saved with ModelCheckpoint. In the tutorial, the author offers two options for getting the encoder portion of the model with trained weights: 1) calling model.encoder (the model has to have been trained in the active kernel for this to work), or 2) loading the trained model with:
ckpt_model = torch.load('[checkpoint name].ckpt')
And then calling
encoder = ckpt_model.encoder
I would like to be able to load the model/encoder from a saved checkpoint, but when I try to do this I get the error: AttributeError: 'dict' object has no attribute 'encoder'
This makes sense to me, because torch.load returns the raw checkpoint dictionary rather than a Lightning module.
When I print the contents of ckpt_model using ckpt_model.keys() I get: dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers']). The state_dict contains weights and biases for the encoder layer.
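Since the weights are already in state_dict, one workaround I have been considering is to filter out the encoder entries and load them into a freshly built encoder. A rough, untested sketch of that idea (it assumes the encoder keys in state_dict are prefixed with "encoder." and that the backbone is the tutorial's resnet18 with the fc head replaced, which may not match my exact setup):

import torch
from torchvision.models import resnet18

# Load the raw checkpoint dict saved by ModelCheckpoint.
ckpt = torch.load('[checkpoint name].ckpt', map_location='cpu')
state_dict = ckpt['state_dict']

# Keep only the encoder weights and strip the "encoder." prefix so the keys
# line up with a standalone encoder module.
encoder_state_dict = {
    k[len('encoder.'):]: v
    for k, v in state_dict.items()
    if k.startswith('encoder.')
}

# Rebuild the encoder architecture and load the trained weights into it.
# (Assumes a resnet18 backbone with the classification head replaced by an
# identity layer -- adjust to match the actual encoder used in training.)
encoder = resnet18()
encoder.fc = torch.nn.Identity()
encoder.load_state_dict(encoder_state_dict)
encoder.eval()

This feels more manual than it should be, though, which is why I am asking.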
Then when I try loading the model with the PyTorch Lightning loader:
BarlowTwins.load_from_checkpoint('[checkpoint name].ckpt')
I get the error: TypeError: __init__() missing 4 required positional arguments: 'encoder', 'encoder_out_dim', 'num_training_samples', and 'batch_size'. I might be able to save those in the checkpoint with save_hyperparameters.
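In case it is relevant, my understanding is that load_from_checkpoint forwards any extra keyword arguments to __init__, so I could probably recreate the untrained encoder and pass the four missing arguments explicitly. Something like this sketch (the numeric values are placeholders for whatever I actually trained with, and it assumes the BarlowTwins class is importable in the current script):

import torch.nn as nn
from torchvision.models import resnet18

# Recreate the (untrained) encoder architecture; the trained weights are
# restored from the checkpoint's state_dict when the module is loaded.
encoder = resnet18()
encoder.fc = nn.Identity()

# Extra keyword arguments to load_from_checkpoint are passed to __init__,
# which should satisfy the four missing constructor arguments.
model = BarlowTwins.load_from_checkpoint(
    '[checkpoint name].ckpt',
    encoder=encoder,
    encoder_out_dim=512,          # placeholder value
    num_training_samples=10000,   # placeholder value
    batch_size=256,               # placeholder value
)
encoder = model.encoder

Alternatively, if I retrain, calling self.save_hyperparameters(ignore=['encoder']) in __init__ should store the scalar arguments in the checkpoint so that only the encoder would need to be passed at load time.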
My question: how can I load the encoder portion of the model in the simplest way possible? I only need the encoder, not the full model.
Thanks so much in advance for any help!
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
