'Tensor shape changes in tf.while_loop()
# initialize start token
target = tf.constant([[2]], dtype=tf.int32, shape=[1, 1]) # 2 - <BOS>
dummy_outputs = [4, 7, 1, 9, 15, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # for testing
self.max_length = 10
def generate_next_token(tar):
# out, _ = self.model([features, tf.expand_dims(tar, 0)], training=False) # shape: [1, seq_length, vocab_size]
# out = out[:, -1, :] # take last token of sequence. shape: [1, vocab_size]
# out = tf.random.categorical(out, 1) # shape: [1, 1]
# out = tf.cast(out, tf.int32)
out = tf.constant([[dummy_outputs.pop(0)]], dtype=tf.int32, shape=[1, 1])
print(tar.shape)
tar = tf.concat([tar, out], axis=-1)
return tar
def end_of_sequence_not_reached(tar):
print(tar.shape)
return tf.math.logical_and(tf.less(tf.shape(tar)[-1], self.max_length),
tf.not_equal(tar[-1], 3)) # 3 - <EOS>
target = tf.while_loop(cond=end_of_sequence_not_reached, body=generate_next_token, loop_vars=[target], shape_invariants=[tf.TensorShape([1, None])])
Somehow, the shape of tar changes from (1, n) to (n) after every iteration of the tf.while_loop, and dimensions of length 1 are lost, so I have to work around it by doing
if len(tar.shape) < len(prev_shape):
tar = tf.expand_dims(tar, 0)
Why is this and how can I prevent it?
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|
