Tensor shape changes in tf.while_loop()

    # initialize start token
    target = tf.constant([[2]], dtype=tf.int32, shape=[1, 1])     # 2 - <BOS>

    dummy_outputs = [4, 7, 1, 9, 15, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # for testing
    self.max_length = 10
    def generate_next_token(tar):
        # out, _ = self.model([features, tf.expand_dims(tar, 0)], training=False) # shape: [1, seq_length, vocab_size]
        # out = out[:, -1, :]  # take last token of sequence. shape: [1, vocab_size]
        # out = tf.random.categorical(out, 1)  # shape: [1, 1]
        # out = tf.cast(out, tf.int32)
        out = tf.constant([[dummy_outputs.pop(0)]], dtype=tf.int32, shape=[1, 1])
        print(tar.shape)
        tar = tf.concat([tar, out], axis=-1)
        return tar

    def end_of_sequence_not_reached(tar):
        print(tar.shape)
        return tf.math.logical_and(tf.less(tf.shape(tar)[-1], self.max_length),
                                   tf.not_equal(tar[-1], 3)) # 3 - <EOS>

    target = tf.while_loop(cond=end_of_sequence_not_reached,
                           body=generate_next_token,
                           loop_vars=[target],
                           shape_invariants=[tf.TensorShape([1, None])])

Somehow, the shape of tar changes from (1, n) to (n,) after every iteration of the tf.while_loop: the leading dimension of length 1 is lost, so I have to work around it with

    if len(tar.shape) < len(prev_shape):  # prev_shape: shape of tar recorded earlier (not shown above)
        tar = tf.expand_dims(tar, 0)      # put the lost dimension of length 1 back

Why does this happen, and how can I prevent it?
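
For reference, here is a minimal, self-contained sketch of the loop that runs without the model: the dummy token list stands in for the model output, a rank check replaces the prev_shape comparison, and the cond reads the last token as tar[0, -1] so that it returns a scalar. These details are my simplifications for testing, not the exact original code.

    import tensorflow as tf

    max_length = 10
    dummy_outputs = [4, 7, 1, 9, 15, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]  # pretend model outputs; 3 - <EOS>

    def generate_next_token(tar):
        # Restore the leading batch dimension if it was dropped (the workaround above).
        if tar.shape.rank < 2:
            tar = tf.expand_dims(tar, 0)
        out = tf.constant([[dummy_outputs.pop(0)]], dtype=tf.int32)  # shape: [1, 1]
        print(tar.shape)
        return tf.concat([tar, out], axis=-1)                        # shape: [1, n + 1]

    def end_of_sequence_not_reached(tar):
        return tf.math.logical_and(tf.less(tf.shape(tar)[-1], max_length),
                                   tf.not_equal(tar[0, -1], 3))      # 3 - <EOS>

    target = tf.constant([[2]], dtype=tf.int32)                      # 2 - <BOS>
    target = tf.while_loop(cond=end_of_sequence_not_reached,
                           body=generate_next_token,
                           loop_vars=[target],
                           shape_invariants=[tf.TensorShape([1, None])])
    print(target)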


