TypeError: load() missing 1 required positional argument: 'sess' when loading a model from the TF2 Object Detection API tutorial SavedModel example

I am following the TensorFlow 2 Object Detection API tutorial on Read the Docs, specifically the "Object Detection From TF2 Saved Model" example: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/auto_examples/plot_object_detection_saved_model.html

I successfully downloaded the model using code from the tutorial:

import tensorflow as tf

# Download and extract model
def download_model(model_name, model_date):
    base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
    model_file = model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(fname=model_name,
                                        origin=base_url + model_date + '/' + model_file,
                                        untar=True)
    return str(model_dir)

MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)

But when I run the load script from the tutorial:

import time
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils

PATH_TO_SAVED_MODEL = PATH_TO_MODEL_DIR + "/saved_model"

print('Loading model...', end='')
start_time = time.time()

# Load saved model and build the detection function
detect_fn = tf.saved_model.load(export_dir=PATH_TO_SAVED_MODEL, tags=None, options=None)

end_time = time.time()
elapsed_time = end_time - start_time
print('Done! Took {} seconds'.format(elapsed_time))

I get the error: TypeError: load() missing 1 required positional argument: 'sess'

Originally, the only argument I passed to this function was 'PATH_TO_SAVED_MODEL', but I added 'export_dir=', 'tags=None' and 'options=None' after similar error messages prompted me to. For sess, I have tried passing 'sess' and 'sess=None', and I searched the TensorFlow docs for details on the sess argument of load(), but without any luck.
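The error message suggests that, in this environment, tf.saved_model.load resolves to the TensorFlow 1.x loader, whose signature is load(sess, tags, export_dir, import_scope=None, **saver_kwargs), rather than the TF 2.x load(export_dir, tags=None, options=None) that the tutorial expects (for example, because TF 1.x is installed or tf is bound to tensorflow.compat.v1). A minimal sketch for checking which one you have, assuming TensorFlow is importable; loader_style and tf_major_version are hypothetical helper names, and classifying a loader by its first parameter name is only a heuristic:

```python
import inspect


def loader_style(load_fn):
    """Heuristically classify a SavedModel loader by its first parameter name."""
    try:
        params = list(inspect.signature(load_fn).parameters)
    except (TypeError, ValueError):
        return "unknown"  # signature not introspectable
    if params and params[0] == "sess":
        return "tf1"  # TF 1.x style: load(sess, tags, export_dir, ...)
    if params and params[0] == "export_dir":
        return "tf2"  # TF 2.x style: load(export_dir, tags=None, options=None)
    return "unknown"


def tf_major_version(version):
    """Return the major version number from a version string such as '2.3.0'."""
    return int(version.split(".")[0])


if __name__ == "__main__":
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow is not installed in this environment")
    else:
        print("TensorFlow version:", tf.__version__)
        print("Major version:", tf_major_version(tf.__version__))
        print("tf.saved_model.load style:", loader_style(tf.saved_model.load))
```

If the style comes back as "tf1", the tutorial's one-argument call tf.saved_model.load(PATH_TO_SAVED_MODEL) is not expected to work there, since the tutorial targets TensorFlow 2.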

I also wonder whether there was an issue with how I downloaded the model. The TensorFlow docs always call saved_model.save() before using the load function, and I assumed my download script did the equivalent, but do I still need to save the downloaded model before loading it?
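The tutorial's own PATH_TO_MODEL_DIR + "/saved_model" implies that the extracted archive already contains an exported SavedModel, so a separate save step should not be needed. A minimal standard-library sketch for checking that, assuming the layout <model_dir>/saved_model/saved_model.pb (or the text form saved_model.pbtxt); find_saved_model is a hypothetical helper name:

```python
import os


def find_saved_model(model_dir):
    """Return '<model_dir>/saved_model' if it holds a SavedModel graph file, else None."""
    candidate = os.path.join(model_dir, "saved_model")
    # A SavedModel directory stores its graph as saved_model.pb (binary) or saved_model.pbtxt (text).
    for name in ("saved_model.pb", "saved_model.pbtxt"):
        if os.path.isfile(os.path.join(candidate, name)):
            return candidate
    return None


# Usage with the tutorial's variable (assumed to be defined by download_model):
# print(find_saved_model(PATH_TO_MODEL_DIR))
```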

If anyone has any suggestions on how to successfully load the model, I would greatly appreciate it!



Solution 1:[1]

Try using the `load_model()` function:

model = tf.keras.models.load_model('<path_to_model>')

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Tillmann