How to read weights saved in a tensorflow checkpoint file?

I'd like to read the weights and visualize them as images. But I don't see any documentation about model format and how to read the trained weights.



Solution 1:[1]

There's this utility, which has a print_tensors_in_checkpoint_file method: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py

Alternatively, you can use tf.train.Saver to restore the model and call session.run on the variable tensors to get their values as numpy arrays.

Solution 2:[2]

I wrote a snippet in Python:

import os
import sys
import fnmatch
import tensorflow as tf

def extracting(meta_dir):
    var_name = ['2-convolutional/kernel']
    model_name = meta_dir
    configfiles = [os.path.join(dirpath, f)
                   for dirpath, dirnames, files in os.walk(model_name)
                   for f in fnmatch.filter(files, '*.meta')]  # list of .meta files

    with tf.Session() as sess:
        try:
            # A MetaGraph contains both a TensorFlow GraphDef
            # and the associated metadata necessary for running
            # computation in a graph when crossing a process boundary.
            saver = tf.train.import_meta_graph(configfiles[0])
        except Exception:
            print("Unexpected error:", sys.exc_info()[0])
        else:
            # Restore the checkpoint matching the .meta file
            saver.restore(sess, configfiles[-1].split('.')[0])

            # Access the restored graph and list its node names
            graph = tf.get_default_graph()
            inside_list = [n.name for n in graph.as_graph_def().node]

            print('Step: ', configfiles[-1])

            print('Tensor:', var_name[0] + ':0')
            w2 = graph.get_tensor_by_name(var_name[0] + ':0')
            print('Tensor shape: ', w2.get_shape())
            print('Tensor value: ', sess.run(w2))
            w2_saved = sess.run(w2)  # weights as a numpy array
            return w2_saved

You could run it by giving meta_dir as your pre-trained model directory.

Solution 3:[3]

To expand on Yaroslav's answer, print_tensors_in_checkpoint_file is a thin wrapper around py_checkpoint_reader, which lets you concisely access the variables and retrieve tensors as numpy arrays. For example, suppose you have the following files in a folder called tf_weights:

checkpoint  model.ckpt.data-00000-of-00001  model.ckpt.index  model.ckpt.meta

Then you can use py_checkpoint_reader to interact with the weights without necessarily loading the entire model. To do that:

from tensorflow.python.training import py_checkpoint_reader

# Need to say "model.ckpt" instead of "model.ckpt.index" for tf v2
file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)

# Load dictionaries var -> shape and var -> dtype
var_to_shape_map = reader.get_variable_to_shape_map()
var_to_dtype_map = reader.get_variable_to_dtype_map()

Now, the var_to_shape_map dictionary's keys match the variables stored in your checkpoint. This means you can retrieve them with reader.get_tensor, e.g.:

ckpt_vars = list(var_to_shape_map.keys())
reader.get_tensor(ckpt_vars[1])

To summarize all of the above, you can use the following code to get a dictionary of numpy arrays:

from tensorflow.python.training import py_checkpoint_reader

file_name = "./tf_weights/model.ckpt"
reader = py_checkpoint_reader.NewCheckpointReader(file_name)

state_dict = {
    v: reader.get_tensor(v) for v in reader.get_variable_to_shape_map()
}
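To address the original goal of visualizing the weights as images, here is a minimal numpy-only sketch that tiles a convolutional kernel into one 2D grid you could pass to e.g. plt.imshow. The (H, W, in_channels, out_channels) kernel layout and the global min-max normalization are assumptions; in practice the array would come from reader.get_tensor.

```python
import numpy as np

def kernels_to_grid(kernel, pad=1):
    """Tile a conv kernel of shape (H, W, C_in, C_out) into a single
    2D grid image (one tile per (C_in, C_out) pair), normalized to [0, 1]."""
    h, w, c_in, c_out = kernel.shape
    k = kernel.astype(np.float64)
    # Normalize globally to [0, 1] so the grid is displayable as an image
    k = (k - k.min()) / (k.max() - k.min() + 1e-12)
    rows, cols = c_in, c_out
    grid = np.ones((rows * (h + pad) - pad, cols * (w + pad) - pad))
    for i in range(rows):
        for j in range(cols):
            grid[i * (h + pad):i * (h + pad) + h,
                 j * (w + pad):j * (w + pad) + w] = k[:, :, i, j]
    return grid

# Example with random stand-in weights; use reader.get_tensor(...) in practice
grid = kernels_to_grid(np.random.randn(5, 5, 3, 8))
print(grid.shape)  # (17, 47)
```

The one-pixel padding between tiles (filled with white, i.e. 1.0) makes the individual filters easy to tell apart in the rendered image.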

Solution 4:[4]

For tensorflow 2.4, when using tf.train.Checkpoint, I have the usual checkpoint files (the checkpoint index plus the .data and .index shards) in my model directory.

To save

import tensorflow as tf

model = ...       # your tf.keras.Model
optimizer = ...   # your tf.keras optimizer
model_path = ...  # e.g. './models/{exp_name}/epoch_{num}'

ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_obj.save(file_prefix=model_path)

To load

import tensorflow as tf

model = ...       # your tf.keras.Model, initialized again
optimizer = ...   # your tf.keras optimizer
model_path = ...  # e.g. './models/{exp_name}/epoch_{num}'

ckpt_obj = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_obj.restore(save_path=tf.train.latest_checkpoint(str(model_path))).assert_consumed()

To check individual weights

import tensorflow as tf
from tensorflow.python.training import py_checkpoint_reader

model_path = ...  # e.g. './models/{exp_name}/epoch_{num}'
ckpt_path = tf.train.latest_checkpoint(str(model_path))
reader = py_checkpoint_reader.NewCheckpointReader(ckpt_path)

dtype_map = reader.get_variable_to_dtype_map()
shape_map = reader.get_variable_to_shape_map()

state_dict = {v: reader.get_tensor(v) for v in shape_map}
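Once state_dict holds plain numpy arrays, you can also persist and reload them without TensorFlow at all, e.g. via numpy's .npz format. This is a sketch; the variable names, shapes, and file location are made up for illustration.

```python
import os
import tempfile
import numpy as np

# Stand-in for a checkpoint state_dict (names and shapes are hypothetical)
state_dict = {
    "model/conv1/kernel": np.random.randn(3, 3, 1, 8),
    "model/conv1/bias": np.zeros(8),
}

# Write all arrays to a single .npz archive, keyed by variable name
path = os.path.join(tempfile.gettempdir(), "weights.npz")
np.savez(path, **state_dict)

# Reload later -- no TensorFlow required
loaded = dict(np.load(path))
print(sorted(loaded.keys()))
```

This can be handy for inspecting or reusing weights in environments where installing TensorFlow is impractical.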

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Yaroslav Bulatov
Solution 2
Solution 3 xhlulu
Solution 4 pmod