How to access tensor_content values in TensorProto in TensorFlow?
This is similar to "How to access values in protos in TensorFlow?", but that question doesn't cover this case.
I see a bytes tensor_content attribute in TensorProto. I'm trying to get information about the nodes through:
    for node in tf.get_default_graph().as_graph_def().node:
        node.attr['value'].tensor.tensor_content  # decode these bytes
For information, the print of a node looks something like this:
    name: "conv2d/convolution/Shape"
    op: "Const"
    device: "/device:GPU:0"
    attr {
      key: "dtype"
      value {
        type: DT_INT32
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_INT32
          tensor_shape {
            dim {
              size: 4
            }
          }
          tensor_content: "\003\000\000\000\003\000\000\000\001\000\000\000 \000\000\000"
        }
      }
    }
Solution 1:[1]
    import tensorflow as tf
    from tensorflow.python.framework import tensor_util

    for n in tf.get_default_graph().as_graph_def().node:
        if "value" in n.attr:  # skip nodes that carry no constant tensor
            print(tensor_util.MakeNdarray(n.attr["value"].tensor))
Solution 2:[2]
Decode the tensor_content bytes and then reshape the result to the given shape:
    for node in tf.get_default_graph().as_graph_def().node:
        if "value" not in node.attr:
            continue  # only nodes with a "value" attr hold a TensorProto
        tensor_bytes = node.attr["value"].tensor.tensor_content
        tensor_dtype = node.attr["value"].tensor.dtype
        tensor_shape = [x.size for x in node.attr["value"].tensor.tensor_shape.dim]
        tensor_array = tf.decode_raw(tensor_bytes, tensor_dtype)
        # tensor_array is a tf.Tensor; evaluate it (e.g. in a tf.Session) to get the values
        tensor_array = tf.reshape(tensor_array, tensor_shape)
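To see what the bytes themselves contain, here is a minimal sketch that decodes the tensor_content string from the question using only the Python standard library. It assumes the bytes are little-endian (the usual case on x86 hosts) and that the dtype is DT_INT32, as in the printed node; `decode_int32_content` is a hypothetical helper name, not part of TensorFlow.

```python
import struct

# Raw bytes copied from the tensor_content field in the question.
# The escapes are octal; the literal space is byte 0x20, i.e. 32.
raw = b"\003\000\000\000\003\000\000\000\001\000\000\000 \000\000\000"


def decode_int32_content(content):
    """Hypothetical helper: unpack little-endian int32 values from raw bytes.

    Assumption: the content was written on a little-endian machine.
    """
    count = len(content) // 4  # DT_INT32 uses 4 bytes per element
    return list(struct.unpack("<%di" % count, content))


values = decode_int32_content(raw)
print(values)  # [3, 3, 1, 32]
```

The tensor_shape in the question has a single dimension of size 4, so the four decoded values already match the shape and no reshape is needed. If NumPy is available, `np.frombuffer(raw, dtype="<i4")` should give the same values as an array.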
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jkschin |
| Solution 2 | zong fan |
