tf.train.Checkpoint and loading weights
I'm training a seq2seq model with TensorFlow. Correct me if I'm wrong: I understand that tf.train.Checkpoint saves only checkpoint files, which are useful only when the source code that will use the saved parameter values is available. I would like to know how I can instantiate my model later on and load the trained weights from a checkpoint in order to test it.
checkpoint_dir = 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
Here is the code for training:
EPOCHS = 20

for epoch in range(EPOCHS):
    start = time.time()
    enc_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, enc_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {}'.format(epoch + 1, batch, batch_loss.numpy()))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)
Regards
Solution 1:[1]
Here is a proposed answer, which suggests using a tf.train.CheckpointManager:
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts', max_to_keep=3)

# `net` is your custom model; `manager` manages the checkpoints.
def train_and_checkpoint(net, manager):
    checkpoint.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")

    EPOCHS = 20
    for epoch in range(EPOCHS):
        start = time.time()
        enc_hidden = encoder.initialize_hidden_state()
        total_loss = 0

        for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
            batch_loss = train_step(inp, targ, enc_hidden)
            total_loss += batch_loss

            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {}'.format(epoch + 1, batch, batch_loss.numpy()))

        # saving (checkpoint) the model every 2 epochs
        if (epoch + 1) % 2 == 0:
            save_path = manager.save()
            print("Saved checkpoint for epoch {}: {}".format(epoch + 1, save_path))

# Run the function once to train and save the checkpoints.
train_and_checkpoint(net, manager)

# Instantiate a new model, restore the weights, and resume training
# from the last checkpoint.
opt = optimizer  # the optimizer passed earlier
net = Net()      # your custom model
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
manager = tf.train.CheckpointManager(checkpoint, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)  # restores weights from the last checkpoint and resumes training
Ref - https://www.tensorflow.org/guide/checkpoint#train_and_checkpoint_the_model
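To directly answer the question of loading the weights later for testing (without a CheckpointManager), here is a minimal hedged sketch. It assumes you rebuild the encoder and decoder with the same classes and constructor arguments used at training time; the tiny Sequential models below are only stand-ins for those classes. The attribute names passed to tf.train.Checkpoint ('encoder', 'decoder') must match the names used when saving, because that is how variables are matched on restore.

```python
import tensorflow as tf

checkpoint_dir = 'training_checkpoints'

# Stand-ins: in the real seq2seq setup these would be the same
# Encoder/Decoder classes (same constructor arguments) used for training.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
decoder = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])

# Recreate the Checkpoint with the SAME attribute names used when saving.
checkpoint = tf.train.Checkpoint(encoder=encoder, decoder=decoder)

latest = tf.train.latest_checkpoint(checkpoint_dir)  # newest 'ckpt-N' file
if latest:
    # expect_partial() silences warnings about objects that were saved
    # (e.g. the optimizer) but are intentionally not restored for inference.
    checkpoint.restore(latest).expect_partial()
    print('Restored from', latest)
else:
    print('No checkpoint found; using fresh weights.')
```

After restoring, you can call the encoder/decoder for inference as usual; the optimizer is only needed if you intend to continue training.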
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Tensorflow Support |
