How to continue training from checkpoints using object_detector.EfficientDetLite4Spec in TensorFlow Lite

Previously, I set "grad_checkpoint=true" for my EfficientDetLite4 model in config.yaml, and training successfully generated some checkpoints. However, I can't figure out how to use these checkpoints to continue training from them.

Every time I train the model, it just starts from the beginning, not from my checkpoints.

The following picture shows my colab file system structure:

<img src="https://i.stack.imgur.com/8EhPx.jpg"/>

The following picture shows where my checkpoints are stored:

<img src="https://i.stack.imgur.com/Ve5al.jpg"/>

The following code shows how I configure the model and how I train with the model.

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

train_data, validation_data, test_data = \
    object_detector.DataLoader.from_csv('csv_path')

spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, strategy=spec_strategy  # spec_strategy must be defined beforehand (not shown here)
)

model = object_detector.create(train_data, model_spec=spec, batch_size=3, 
    train_whole_model=True, validation_data=validation_data)


Solution 1:[1]

The source code is the answer!

I ran into the same problem and found that the model_dir we pass to the TFLite Model Maker object detector API is only used for saving the model's weights, which is why the API never restores from checkpoints.

Looking at the source code of this API, I noticed that it internally uses the standard model.compile and model.fit functions and saves the model's weights through the callbacks parameter of model.fit.
This means that, provided we can get the internal Keras model, we can restore our checkpoints with model.load_weights!

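If you want to know more about what some of the functions I use below do, see the source code of the modules they are imported from (for example tensorflow_examples.lite.model_maker.third_party.efficientdet.keras).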

This is the code:

#Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

#Import the same libs that TFLite Model Maker uses internally
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib

#Setup variables
batch_size = 6 #or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..." #whatever your checkpoint directory is

#Create whichever object detector's spec you want
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2', 
    hparams='', #enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir, 
    epochs=epochs, 
    batch_size=batch_size,
    steps_per_execution=1, 
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, 
    strategy=None, 
    tpu=None, 
    gcp_project=None,
    tpu_zone=None, 
    use_xla=False, 
    profile=False, 
    debug=False, 
    tf_random_seed=111111,
    verbose=1
)

#Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')

#Create the object detector 
detector = object_detector.create(train_data, 
                                model_spec=spec, 
                                batch_size=batch_size, 
                                train_whole_model=True, 
                                validation_data=validation_data,
                                epochs = epochs,
                                do_train = False
                                )
"""
From here on we use internal/"private" functions of the API;
you can tell because the method names begin with an underscore
"""

#Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training = True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training = False)

#Get the internal Keras model
model = detector.create_model()

#Replicate the setup that the API does internally
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config) #This is the model.compile call basically
model.summary()

"""
Here we restore the weights
"""

#Load the weights from the latest checkpoint, if there is one
completed_epochs = 0 #stays 0 (train from scratch) when no checkpoint is found
latest = tf.train.latest_checkpoint(checkpoint_dir) #example: "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
if latest is not None:
  completed_epochs = int(latest.split("/")[-1].split("-")[-1]) #the epoch the training was at when it was last interrupted
  model.load_weights(latest)
  print("Checkpoint found {}".format(latest))
else:
  print("Checkpoint not found in {}".format(checkpoint_dir))

#Train the model 
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs, 
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=train_lib.get_callbacks(config.as_dict(), validation_ds) #This is for saving checkpoints at the end of every epoch
)

#Save/export the trained model
#Tip: for integer quantization you simply have to NOT SPECIFY 
#the quantization_config parameter of the detector.export method
export_dir = "/content/..." #save the tflite wherever you want
quant_config = QuantizationConfig.for_float16() #or whatever quantization you want
detector.model = model #inject our trained model into the object detector
detector.export(export_dir = export_dir, tflite_filename='model.tflite', quantization_config = quant_config)
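
The epoch-parsing step in the restore code above can be pulled into a small helper that also covers the case where no checkpoint exists. This is only a sketch: the helper name completed_epochs_from_checkpoint is made up for illustration, and it assumes checkpoints are named "<prefix>-<epoch>" (such as "ckpt-35" in the example path above).

```python
import os
import re

# Assumed naming scheme: "<prefix>-<epoch>", e.g. "ckpt-35".
_EPOCH_SUFFIX = re.compile(r"-(\d+)$")


def completed_epochs_from_checkpoint(checkpoint_path):
    """Return the epoch number encoded in a checkpoint path.

    Returns 0 when the path is None (no checkpoint found) or when the
    name does not end in "-<number>", so training starts from epoch 0.
    """
    if checkpoint_path is None:
        return 0
    name = os.path.basename(checkpoint_path.rstrip("/"))
    match = _EPOCH_SUFFIX.search(name)
    return int(match.group(1)) if match else 0


# Hypothetical usage: the value can be passed as initial_epoch to model.fit.
print(completed_epochs_from_checkpoint(
    "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"))
print(completed_epochs_from_checkpoint(None))
```

Keep in mind that in Keras the epochs argument of model.fit is the index of the final epoch, not the number of additional epochs, so with epochs=50 and initial_epoch=35 training runs for the remaining 15 epochs.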

Solution 2:[2]

While trying to restore the weights, I get: "Checkpoint not found: 'NoneType' object has no attribute 'split'".
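
This message appears when tf.train.latest_checkpoint returns None, which it does when it finds no checkpoint state in the directory it was given (for example because the path is wrong or the drive is not mounted). A quick way to check, sketched below without TensorFlow, is to look at what the directory actually contains. The helper name inspect_checkpoint_dir is hypothetical, and the expected layout (a file named "checkpoint" next to "ckpt-N.index" and "ckpt-N.data-*" files) is an assumption based on how TensorFlow-format weight checkpoints are usually written.

```python
import os
import tempfile


def inspect_checkpoint_dir(checkpoint_dir):
    """Summarize a checkpoint directory using only the file system."""
    if not os.path.isdir(checkpoint_dir):
        return {"exists": False, "has_state_file": False, "prefixes": []}
    files = os.listdir(checkpoint_dir)
    # Each saved checkpoint is assumed to have a "<prefix>.index" file.
    prefixes = sorted(f[:-len(".index")] for f in files if f.endswith(".index"))
    return {
        "exists": True,
        # The "checkpoint" state file records the latest checkpoint prefix.
        "has_state_file": "checkpoint" in files,
        "prefixes": prefixes,
    }


# Demo on a throwaway directory that mimics the assumed layout.
with tempfile.TemporaryDirectory() as demo_dir:
    for name in ("checkpoint", "ckpt-1.index", "ckpt-1.data-00000-of-00001"):
        open(os.path.join(demo_dir, name), "w").close()
    print(inspect_checkpoint_dir(demo_dir))
```

If "exists" is False or "has_state_file" is False for your checkpoint_dir, the restore code has nothing to load and the path should be checked first.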

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2 sanjay