Does the PyTorch Lightning Trainer use the validation data to optimize the model's weights?
I am currently working with PyTorch Forecasting, which heavily uses PyTorch Lightning. I am applying the PyTorch Lightning Trainer to train a Temporal Fusion Transformer model, roughly following the outline of this example. My rough training code and model definition look like this:
import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss

# training dataset: all data up to the training cutoff
training = TimeSeriesDataSet(
    df_train[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    target="target",
    group_ids=["group"],
    max_prediction_length=90,
    min_encoder_length=365 // 2,
    max_encoder_length=365,
    time_varying_unknown_reals=["target"],
    time_varying_known_reals=["time_idx"],
)

# validation dataset: predict the last max_prediction_length points of each series
validation = TimeSeriesDataSet.from_dataset(training, df_train, predict=True, stop_randomization=True)

# create dataloaders for model
batch_size = 4
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=res.suggestion(),  # res comes from the learning rate finder, defined elsewhere
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=10,
    reduce_on_plateau_patience=4,
    time_varying_reals_encoder=["target"],
    time_varying_reals_decoder=["target"],
)

trainer = pl.Trainer(
    max_epochs=15,
    gpus=0,
    weights_summary="top",
    gradient_clip_val=0.1,
    limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],  # defined elsewhere, as in the linked example
    logger=logger,
)

trainer.fit(
    tft,
    train_dataloader,
    val_dataloader,
)
Now my question is whether the validation data has any influence on the optimization of the model. I have been playing around with the max_prediction_length parameter, and it seems that the model performs better when I set the validation time window to a larger time frame. Does the PyTorch Lightning Trainer use the validation data to optimize the model, or am I missing something else?
Thanks a lot in advance!
Solution 1:[1]
Since PyTorch Forecasting is an abstraction built on top of PyTorch Lightning, we can refer to the documentation of the latter framework's Trainer (https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html). The training loop it runs is described there with roughly the following pseudocode:
# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # calls hooks like this one
    on_train_batch_start()

    # train step
    loss = training_step(batch)

    # clear gradients
    optimizer.zero_grad()

    # backward
    loss.backward()

    # update parameters
    optimizer.step()

    losses.append(loss)
In the pseudocode above, we can see that the trainer computes the loss only on batches from the train_dataloader and backpropagates those losses. This means that the validation set is not used to update the model's weights.
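If you want to convince yourself of this, a quick experiment is to run only the validation loop on a freshly initialized model and check that its parameters do not change. The following is a minimal sketch, not from the original answer: the TinyModule class, the random tensors, and the chosen hyperparameters are made up purely for illustration.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class TinyModule(pl.LightningModule):
    """Illustrative module: a single linear layer trained with MSE."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        # Lightning runs this with gradients disabled and the model in eval mode;
        # no backward pass or optimizer step happens here.
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

model = TinyModule()
before = [p.detach().clone() for p in model.parameters()]

# run only the validation loop on the model
pl.Trainer(max_epochs=1, logger=False).validate(model, loader)

# every parameter is unchanged: validation does not optimize the weights
print(all(torch.equal(b, a) for b, a in zip(before, model.parameters())))  # True

Note that during fit the validation loss can still influence training indirectly, for example through an EarlyStopping callback or a ReduceLROnPlateau learning-rate scheduler (as suggested by the early_stop_callback and reduce_on_plateau_patience in your setup), but it never enters a backward pass.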
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jhonkola |
