How to read weights saved in a TensorFlow checkpoint file?
I'd like to read the weights and visualize them as images. But I don't see any documentation about model format and how to read the trained weights.
Solution 1:[1]
There's this utility which has a print_tensors_in_checkpoint_file method: http://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py
Alternatively, you can use Saver to restore the model and call session.run on the variable tensors to get their values as numpy arrays.
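As a sketch of that Saver-based approach, here is a self-contained toy example (assuming the TF1 API via tensorflow.compat.v1; the variable name w1 and the temporary checkpoint path are made up for illustration):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Build a tiny graph with one variable and save it to a checkpoint.
w1 = tf.get_variable('w1', initializer=tf.constant([[1.0, 2.0], [3.0, 4.0]]))
saver = tf.train.Saver()
ckpt_path = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)

# Restore the checkpoint and fetch the weights as a numpy array.
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    value = sess.run(w1)
    print(value.shape)  # (2, 2)
```

The array returned by sess.run can then be handed straight to an image plotting routine.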
Solution 2:[2]
I wrote a snippet in Python:
import os
import sys
import fnmatch

import tensorflow as tf

def extracting(meta_dir):
    var_name = ['2-convolutional/kernel']
    model_name = meta_dir
    # Collect all .meta files under the model directory
    configfiles = [os.path.join(dirpath, f)
                   for dirpath, dirnames, files in os.walk(model_name)
                   for f in fnmatch.filter(files, '*.meta')]
    with tf.Session() as sess:
        try:
            # A MetaGraph contains both a TensorFlow GraphDef
            # and the metadata necessary for running computation
            # in a graph when crossing a process boundary.
            saver = tf.train.import_meta_graph(configfiles[0])
        except Exception:
            print("Unexpected error:", sys.exc_info()[0])
        else:
            # Restore from the checkpoint matching the latest .meta file
            saver.restore(sess, os.path.splitext(configfiles[-1])[0])
            # Now we can access the restored graph and its variables
            graph = tf.get_default_graph()
            inside_list = [n.name for n in graph.as_graph_def().node]
            print('Step:  ', configfiles[-1])
            print('Tensor:', var_name[0] + ':0')
            w2 = graph.get_tensor_by_name(var_name[0] + ':0')
            print('Tensor shape: ', w2.get_shape())
            print('Tensor value: ', sess.run(w2))
            w2_saved = sess.run(w2)  # the weights as a numpy array
You can run it by passing meta_dir as your pre-trained model directory.
Solution 3:[3]
To expand on Yaroslav's answer, print_tensors_in_checkpoint_file is a thin wrapper around py_checkpoint_reader, which lets you concisely access the variables and retrieve the tensors in numpy format. For example, suppose you have the following files in a folder called tf_weights:
checkpoint model.ckpt.data-00000-of-00001 model.ckpt.index model.ckpt.meta
Then you can use py_checkpoint_reader to interact with the weights without necessarily loading the entire model. To do that:
from tensorflow.python.training import py_checkpoint_reader
# Need to say "model.ckpt" instead of "model.ckpt.index" for tf v2
file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)
# Load dictionaries var -> shape and var -> dtype
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()
Now, the var_to_shape_map dictionary's keys match the variables stored in your checkpoint. This means you can retrieve them with reader.get_tensor, e.g.:
ckpt_vars = list(var_to_shape_map.keys())
reader.get_tensor(ckpt_vars[1])
To summarize all of the above, you can use the following code to get a dictionary of numpy arrays:
from tensorflow.python.training import py_checkpoint_reader
file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)
state_dict = {
v: reader.get_tensor(v) for v in reader.get_variable_to_shape_map()
}
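Since the goal of the original question is to visualize the weights as images, the arrays in such a state_dict can be rescaled to uint8 pixels. A minimal sketch using numpy only (kernel_to_images is a made-up helper; w stands in for a conv kernel of shape (height, width, in_channels, out_channels) as returned by reader.get_tensor):

```python
import numpy as np

def kernel_to_images(w):
    """Normalize a conv kernel of shape (h, w, in_ch, out_ch) to uint8
    images, one 2-D tile per output channel (first input channel only)."""
    w = np.asarray(w, dtype=np.float32)
    lo, hi = w.min(), w.max()
    scaled = (w - lo) / (hi - lo + 1e-8)    # map values into [0, 1]
    imgs = (scaled * 255).astype(np.uint8)  # map values into [0, 255]
    return [imgs[:, :, 0, k] for k in range(imgs.shape[-1])]

# Example with a random stand-in "kernel" of shape (3, 3, 1, 4)
tiles = kernel_to_images(np.random.randn(3, 3, 1, 4))
print(len(tiles), tiles[0].shape)  # 4 (3, 3)
```

Each tile can then be displayed with any image viewer or plotting library.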
Solution 4:[4]
For TensorFlow 2.4, when using tf.train.Checkpoint, the workflow is as follows.
To save
import tensorflow as tf
model = # tf.keras.Model
optimizer = # tf.keras.optimizer
model_path = # './models/{exp_name}/epoch_{num}'
ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_obj.save(file_prefix=model_path)
To load
import tensorflow as tf
model = # tf.keras.Model # need to initialize the model again
optimizer = # tf.keras.optimizer
model_path = # './models/{exp_name}/epoch_{num}'
ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_path))).assert_consumed()
To check individual weights
import tensorflow as tf
from tensorflow.python.training import py_checkpoint_reader
model_path = # './models/{exp_name}/epoch_{num}'
model_path = tf.train.latest_checkpoint(str(model_path))
reader = py_checkpoint_reader.NewCheckpointReader(model_path)
dtype_map = reader.get_variable_to_dtype_map()
shape_map = reader.get_variable_to_shape_map()
state_dict = { v: reader.get_tensor(v) for v in shape_map}
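Note that checkpoints written by tf.train.Checkpoint key the variables by object path, typically ending in /.ATTRIBUTES/VARIABLE_VALUE, alongside bookkeeping entries such as _CHECKPOINTABLE_OBJECT_GRAPH. A small sketch for cleaning such a state_dict (clean_keys is a made-up helper; the raw dict below imitates that key style):

```python
SUFFIX = '/.ATTRIBUTES/VARIABLE_VALUE'

def clean_keys(state_dict):
    """Keep only real variable entries and strip the object-path suffix."""
    return {
        k[:-len(SUFFIX)]: v
        for k, v in state_dict.items()
        if k.endswith(SUFFIX)
    }

# Example with key names in the style tf.train.Checkpoint produces
raw = {
    'model/dense/kernel/.ATTRIBUTES/VARIABLE_VALUE': 'kernel-array',
    'save_counter/.ATTRIBUTES/VARIABLE_VALUE': 'counter',
    '_CHECKPOINTABLE_OBJECT_GRAPH': 'serialized-object-graph',
}
print(clean_keys(raw))
# {'model/dense/kernel': 'kernel-array', 'save_counter': 'counter'}
```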
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Yaroslav Bulatov |
| Solution 2 | |
| Solution 3 | xhlulu |
| Solution 4 | pmod |
