Tensor shape seems to disappear when indexing result of tf.shape(tensor)
When I index the result of tf.shape(tensor) for some tensor, the inferred value unexpectedly turns into None. For example, I ran this code:
>>> from ray.rllib.models.utils import try_import_tf
>>> tf1, tf, tfv = try_import_tf()
>>> tf.compat.v1.enable_eager_execution()
>>> inp = tf.keras.layers.Input(shape=([19, 33, 1]), name='input')
>>> tf.shape(inp)
<KerasTensor: shape=(4,) dtype=int32 inferred_value=[None, 19, 33, 1] (created by layer 'tf.compat.v1.shape')>
The result is as expected. However, when I then run the following code:
>>> tf.shape(inp)[0]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem')>
>>> tf.shape(inp)[1]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_1')>
>>> tf.shape(inp)[2]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_2')>
>>> tf.shape(inp)[3]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem_3')>
The inferred values are all None. What's going on here? Is this expected behaviour?
Solution 1:[1]
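The given code works as expected with TensorFlow 2.8.0. After indexing, only the batch dimension has an inferred value of None; the other dimensions keep their values: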
>>> import tensorflow as tf
>>> print(tf.__version__)
2.8.0
>>> inp = tf.keras.layers.Input(shape=([19, 33, 1]), name='input')
>>> tf.shape(inp)
<KerasTensor: shape=(4,) dtype=int32 inferred_value=[None, 19, 33, 1] (created by layer 'tf.compat.v1.shape_5')>
>>> tf.shape(inp)[0]
<KerasTensor: shape=() dtype=int32 inferred_value=[None] (created by layer 'tf.__operators__.getitem')>
>>> tf.shape(inp)[1]
<KerasTensor: shape=() dtype=int32 inferred_value=[19] (created by layer 'tf.__operators__.getitem_1')>
>>> tf.shape(inp)[2]
<KerasTensor: shape=() dtype=int32 inferred_value=[33] (created by layer 'tf.__operators__.getitem_2')>
>>> tf.shape(inp)[3]
<KerasTensor: shape=() dtype=int32 inferred_value=[1] (created by layer 'tf.__operators__.getitem_3')>
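If the goal is only to read the fixed dimensions, it may be simpler to avoid relying on inferred_value and use the static shape attribute of the input instead. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution enabled. The names static_dims, dynamic_dims and batch are illustrative and not part of the original question; batch is a made-up example input with a batch size of 2.

```python
import tensorflow as tf

inp = tf.keras.layers.Input(shape=(19, 33, 1), name="input")

# Static shape: fixed when the model is defined; the batch dimension is None.
static_dims = tuple(inp.shape)
height, width, channels = static_dims[1:]

# Dynamic shape: computed from a concrete tensor at run time.
# `batch` is a hypothetical example input (assumption, batch size 2).
batch = tf.zeros([2, 19, 33, 1])
dynamic_dims = tf.shape(batch).numpy().tolist()

print(static_dims)
print(dynamic_dims)
```

The static shape is enough for the dimensions declared in Input, while tf.shape is the usual choice for a dimension that is only known at run time, such as the batch size.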
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | TFer |
