TypeError: load() missing 1 required positional argument: 'sess' when loading a model from the TF2 Object Detection API tutorial SavedModel example
I am following the TensorFlow 2 Object Detection API tutorial on Read the Docs and am using the "Object Detection From TF2 Saved Model" example: https://tensorflow-object-detection-api-tutorial.readthedocs.io/en/latest/auto_examples/plot_object_detection_saved_model.html
I successfully downloaded the model using code from the tutorial:
```python
# Download and extract model
def download_model(model_name, model_date):
    base_url = 'http://download.tensorflow.org/models/object_detection/tf2/'
    model_file = model_name + '.tar.gz'
    model_dir = tf.keras.utils.get_file(fname=model_name,
                                        origin=base_url + model_date + '/' + model_file,
                                        untar=True)
    return str(model_dir)

MODEL_DATE = '20200711'
MODEL_NAME = 'centernet_hg104_1024x1024_coco17_tpu-32'
PATH_TO_MODEL_DIR = download_model(MODEL_NAME, MODEL_DATE)
```
But when I run the load script from the tutorial:
```python
import time
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils

PATH_TO_SAVED_MODEL = PATH_TO_MODEL_DIR + "/saved_model"

print('Loading model...', end='')
start_time = time.time()

# Load saved model and build the detection function
detect_fn = tf.saved_model.load(export_dir=PATH_TO_SAVED_MODEL, tags=None, options=None)

end_time = time.time()
elapsed_time = end_time - start_time
print('Done! Took {} seconds'.format(elapsed_time))
```
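A likely explanation, offered here as an assumption since the thread does not confirm it: the `sess` error matches the TensorFlow 1.x loader. In TF 1.x, and through `tf.compat.v1` in TF 2.x, `tf.saved_model.load` has the signature `load(sess, tags, export_dir, import_scope=None, **saver_kwargs)`, while the TF2 tutorial relies on the TF 2.x `tf.saved_model.load(export_dir, tags=None, options=None)`, which takes no session. A script that does `import tensorflow.compat.v1 as tf` would hit the same error even on TF 2.x. The minimal sketch below checks the installed version before loading; `is_tf2` and `load_detection_fn` are hypothetical helper names, and TensorFlow is only imported inside `load_detection_fn` so the version check can be tested on its own.

```python
import re


def is_tf2(version: str) -> bool:
    """Return True if a TensorFlow version string such as '2.3.0' is 2.x or newer."""
    match = re.match(r"(\d+)\.", version)
    if match is None:
        raise ValueError(f"Unrecognised TensorFlow version: {version!r}")
    return int(match.group(1)) >= 2


def load_detection_fn(path_to_saved_model: str):
    """Load an Object Detection API SavedModel, failing early under TF 1.x."""
    # Imported lazily so is_tf2() can be used without TensorFlow installed.
    import tensorflow as tf

    if not is_tf2(tf.__version__):
        raise RuntimeError(
            f"TensorFlow {tf.__version__} is installed; the TF2 tutorial needs "
            "TensorFlow 2.x, where tf.saved_model.load() takes no session."
        )
    # TF 2.x loader: only the export directory is required.
    return tf.saved_model.load(path_to_saved_model)
```

With the tutorial's variables this would be called as `detect_fn = load_detection_fn(PATH_TO_SAVED_MODEL)`. If the check fails, upgrading to TensorFlow 2.x (and importing it as plain `import tensorflow as tf`) is the usual direction, but that is a suggestion, not something verified in this thread.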
I get the error `TypeError: load() missing 1 required positional argument: 'sess'`. Originally the only argument I passed to this function was `PATH_TO_SAVED_MODEL`; I added `export_dir=`, `tags=None` and `options=None` after similar error messages prompted me to. For `sess`, I have tried passing `sess` and `sess=None`, and I searched the TF docs for details on the `sess` argument of `load()`, but have not had any luck.
I also wonder whether the problem is how I downloaded the model. The TF docs always call `tf.saved_model.save()` before `tf.saved_model.load()`. I assumed my download script produced the same result, but do I still need to save the downloaded model before loading it?
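On the second question: the tutorial loads `PATH_TO_MODEL_DIR + "/saved_model"` directly after downloading, which suggests the tarball already contains an exported SavedModel, so a separate `tf.saved_model.save()` call should not be needed. A minimal stdlib sketch to confirm the download, assuming the standard SavedModel layout where the export directory holds a `saved_model.pb` (or `saved_model.pbtxt`) file; `has_saved_model` is a hypothetical helper name, not part of TensorFlow or the Object Detection API.

```python
from pathlib import Path


def has_saved_model(model_dir: str) -> bool:
    """Return True if model_dir/saved_model contains a SavedModel protobuf.

    Assumes the extracted tutorial tarball keeps the exported model in a
    'saved_model' subdirectory, as the tutorial's PATH_TO_SAVED_MODEL implies.
    """
    export_dir = Path(model_dir) / "saved_model"
    # A SavedModel is serialized as either binary or text protobuf.
    return any(
        (export_dir / name).is_file()
        for name in ("saved_model.pb", "saved_model.pbtxt")
    )
```

Running `print(has_saved_model(PATH_TO_MODEL_DIR))` after the download step should print `True`; `False` would point to an incomplete download or extraction rather than a missing save step.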
If anyone has any suggestions on how to successfully load the model, I would greatly appreciate it!
Solution 1:[1]
Try using the `tf.keras.models.load_model()` function:
```python
model = tf.keras.models.load_model('<path_to_model>')
```
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Tillmann |
