How do I know the trained model was loaded correctly?
I use PyTorch Lightning for model training, and during training I use `ModelCheckpoint` to save checkpoints. Afterwards, I would like to know whether the model is loaded correctly. Let me know if you need further information.
```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# args, dm, run_name and BYOL are defined elsewhere in my script.
checkpoint_callback = ModelCheckpoint(
    filename='tb1000_{epoch: 02d}-{step}',
    monitor='val/acc@1',
    save_top_k=5,
    mode='max')
wandb_logger = pl.loggers.wandb.WandbLogger(
    name=run_name,
    project=args.project,
    entity=args.entity,
    offline=args.offline,
    log_model='all')
model = BYOL(**args.__dict__, num_classes=dm.num_classes)
trainer = pl.Trainer.from_argparse_args(
    args, logger=wandb_logger, callbacks=[checkpoint_callback])
trainer.fit(model, dm)

# Loading and testing
model_test = BYOL(**args.__dict__, num_classes=dm.num_classes)
path = "/tb100_epoch= 819-step=39359.ckpt"
model_test.load_from_checkpoint(path)
```
Solution 1:[1]
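`load_from_checkpoint()` is a classmethod: it builds a new model instance, loads the trained weights from the checkpoint into it, and returns that instance. It does not modify the object you call it on, so `model_test.load_from_checkpoint(path)` on its own throws the loaded model away and `model_test` keeps its freshly initialized weights. Assign the return value to a variable instead: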
```python
model_test = model_test.load_from_checkpoint(path)
```
or
```python
model_test = BYOL.load_from_checkpoint(path)
```
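The difference between the two calls can be reproduced without PyTorch at all. The toy class below is a hypothetical stand-in, not Lightning's implementation; it only shares the one property that matters here: `load_from_checkpoint` is a classmethod that returns a new object instead of updating the one it is called on.

```python
class ToyModule:
    """Hypothetical stand-in that mimics how a classmethod loader behaves."""

    def __init__(self, weight: float = 0.0):
        self.weight = weight

    @classmethod
    def load_from_checkpoint(cls, checkpoint: dict) -> "ToyModule":
        # Builds and returns a NEW instance; it never touches an existing one.
        return cls(weight=checkpoint["weight"])


ckpt = {"weight": 42.0}  # pretend this is the saved, trained state

model_test = ToyModule()
model_test.load_from_checkpoint(ckpt)   # return value discarded
print(model_test.weight)                # 0.0, still untrained

model_test = ToyModule.load_from_checkpoint(ckpt)  # keep the returned object
print(model_test.weight)                # 42.0
```

Calling the method on the class, as in `BYOL.load_from_checkpoint(path)`, makes it explicit that a new object is being created, which is why that form is harder to misuse.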
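As for the original question of how to tell whether the weights were actually restored, one practical check is to compare the loaded model's `state_dict()` with the tensors stored in the checkpoint file, which Lightning keeps under the `"state_dict"` key. The sketch below uses plain PyTorch so it runs on its own: a small `nn.Sequential` stands in for `BYOL`, and the checkpoint is a hand-made dict that only mimics the layout of a Lightning `.ckpt` file. `state_dicts_match` and `make_model` are hypothetical helper names.

```python
import os
import tempfile

import torch
from torch import nn


def state_dicts_match(model: nn.Module, saved: dict) -> bool:
    """True if the model has exactly the saved keys and every tensor is equal."""
    current = model.state_dict()
    if set(current.keys()) != set(saved.keys()):
        return False
    return all(torch.equal(current[k].cpu(), saved[k].cpu()) for k in saved)


def make_model() -> nn.Module:
    # Hypothetical stand-in for BYOL: any nn.Module works for this check.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo.ckpt")
    # Mimic a Lightning checkpoint: the weights live under "state_dict".
    torch.save({"state_dict": trained.state_dict()}, path)
    saved = torch.load(path, map_location="cpu")["state_dict"]

fresh = make_model()               # randomly initialized, NOT loaded
restored = make_model()
restored.load_state_dict(saved)    # weights copied from the checkpoint

print(state_dicts_match(fresh, saved))     # False
print(state_dicts_match(restored, saved))  # True
```

With the real model, the same helper would be called as `state_dicts_match(model_test, torch.load(path, map_location="cpu")["state_dict"])` after `model_test = BYOL.load_from_checkpoint(path)`: the object whose return value was discarded gives `False`, the returned one gives `True`. On recent PyTorch versions, loading a full Lightning checkpoint this way may require `weights_only=False`, which should only be used for files you trust. `BYOL.load_from_checkpoint(path)` also assumes `BYOL` calls `self.save_hyperparameters()` in its `__init__`; otherwise the constructor arguments have to be passed to `load_from_checkpoint` as keyword arguments.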
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | joe32140 |
