Python TensorFlow Encoder - Sequence-to-Vector - Training - @tf.function(input_signature)

Dear Internet, I am building an encoder model (just the encoding block of a transformer: Seq2Vec), following TensorFlow's Transformer tutorial. I've built an encoder that works fine by itself: it takes an input array and returns a tuple of (1) an array of shape (batch_size, output_size) and (2) a dict of attention weights per encoder layer.

>>> encoder_model = EncoderModel(**model_params)
>>> x = tf.random.uniform(shape=(256, 40), minval=0, maxval=200, dtype=tf.int32)
>>> print(encoder_model(x, training=False)[0].shape)
(256, 16) # works perfectly :-)
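
For context, here's a stripped-down sketch of what the model roughly looks like. The layer layout, hyperparameter names, and the mean-pooling step are placeholders for this question, not my exact code:

import tensorflow as tf

class EncoderModel(tf.keras.Model):
    """Simplified stand-in: embedding, a stack of self-attention layers,
    and two Dense layers mapping the pooled output to the vector size."""

    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, output_size):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.mha_layers = [
            tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
            for _ in range(num_layers)
        ]
        self.dense1 = tf.keras.layers.Dense(dff, activation='relu')
        self.dense2 = tf.keras.layers.Dense(output_size)

    def call(self, x, training=False):
        attention_weights = {}
        x = self.embedding(x)              # (batch, seq_len, d_model)
        for i, mha in enumerate(self.mha_layers):
            x, attn = mha(x, x, return_attention_scores=True, training=training)
            attention_weights[f'encoder_layer_{i + 1}'] = attn
        x = tf.reduce_mean(x, axis=1)      # pool over the sequence axis
        return self.dense2(self.dense1(x)), attention_weights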

I'm really struggling to figure out how to train the model. Following the tutorial, I am using the scheme below:

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')

train_step_signature = [tf.TensorSpec(shape=(None, None), dtype=tf.int64),
                        tf.TensorSpec(shape=(None, None), dtype=tf.int64)]

# loss_function, accuracy_function, optimizer, train_ds and EPOCHS are
# defined as in the tutorial.
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
    with tf.GradientTape() as tape:
        predictions, _ = encoder_model(inp, training=True)  # <-- code breaks here
        loss = loss_function(tar, predictions)
    gradients = tape.gradient(loss, encoder_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, encoder_model.trainable_variables))

    train_loss(loss)
    train_accuracy(accuracy_function(tar, predictions))

for epoch in range(EPOCHS):
    train_loss.reset_states()
    train_accuracy.reset_states()
    for i, (inp, tar) in enumerate(train_ds):
        train_step(inp, tar)

However, because the EncoderModel uses two Dense layers to project the MultiHeadAttention output down to the output size, I am getting the following error:

ValueError: The last dimension of the inputs to a Dense layer should be defined. Found None. Full input shape received: (None, None)
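
The same error reproduces in isolation when a Dense layer is traced with a fully unknown input shape, so I suspect something in my call path applies a Dense layer to a tensor whose last dimension is still None at trace time. A hypothetical minimal repro (not my actual model):

import tensorflow as tf

dense = tf.keras.layers.Dense(16)

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.float32)])
def apply_dense(x):
    # Dense.build() needs a concrete last dimension to create its kernel,
    # so tracing with shape (None, None) raises the same ValueError
    # before any real data is seen.
    return dense(x)

apply_dense(tf.random.uniform((4, 8)))  # ValueError at trace time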

I think this has to do with the @tf.function(input_signature=train_step_signature) line, but I'm not sure how to get around it.
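
One workaround I've considered is pinning the sequence length in the signature, matching the example input above, though I'd prefer to keep it dynamic (a sketch, untested beyond getting past the trace error):

# Fix the input sequence length so every shape downstream of the encoder
# is fully defined at trace time; 40 matches my example input above.
train_step_signature = [tf.TensorSpec(shape=(None, 40), dtype=tf.int64),
                        tf.TensorSpec(shape=(None, None), dtype=tf.int64)]

Is pinning the shape the right approach here, or is there a cleaner way to keep the sequence length dynamic? Any tips? Thanks a mill :)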


