"XLA compilation requires a fixed tensor list size" error on Colab TPU

I am trying to run this model on a Colab TPU, but I am getting the error below (the full traceback follows the code). Can anyone please help? Thank you.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
strategy = tf.distribute.MirroredStrategy(tf.config.list_logical_devices('TPU'))
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
BATCH_SIZE_PER_REPLICA_TPU = 64
GLOBAL_BATCH_SIZE_TPU = BATCH_SIZE_PER_REPLICA_TPU * strategy.num_replicas_in_sync
val_dataset = tf.data.Dataset.from_tensor_slices((val_source_text[:10000], val_target_text[:10000]))
val_dataset = val_dataset.shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE_TPU)
val_dist_dataset = strategy.experimental_distribute_dataset(val_dataset)

class MaskedLoss(tf.keras.losses.Loss):
  def __init__(self):
    super().__init__(name='masked_loss', reduction='none')
    self.loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

  def __call__(self, y_true, y_pred):
    loss = self.loss(y_true, y_pred)
    mask = tf.cast(y_true != 0, tf.float32)
    loss *= mask
    return tf.nn.compute_average_loss(loss, global_batch_size=GLOBAL_BATCH_SIZE_TPU)

class BatchLogs(tf.keras.callbacks.Callback):
  def __init__(self, key):
    super().__init__()
    self.key = key
    self.logs = []

  def on_train_batch_end(self, batch, logs=None):
    if logs is not None:
      self.logs.append(logs[self.key])

class CoverageLoss(tf.keras.losses.Loss):
  def __init__(self):
    super().__init__(name='coverage_loss', reduction='none')

  def __call__(self, coverage_vector, attention_weights):
    return tf.nn.compute_average_loss(tf.math.minimum(coverage_vector, attention_weights),
                                      global_batch_size=GLOBAL_BATCH_SIZE_TPU)

class AttentionEncoderDecoder(tf.keras.Model):
  def __init__(self, units, embedding_dims, input_preprocessor, output_preprocessor, batch_size, use_tf_function=False, LAMBDA=1):
    super(AttentionEncoderDecoder, self).__init__()
    self.units = units
    self.LAMBDA = LAMBDA
    self.BATCH_SIZE = batch_size
    self.embedding = embedding_dims
    self.use_tf_function = use_tf_function
    self.input_preprocessor = input_preprocessor
    self.output_preprocessor = output_preprocessor
    self.MAX_TARGET_VOCAB = self.output_preprocessor.vocabulary_size()
    self.MAX_SOURCE_VOCAB = self.input_preprocessor.vocabulary_size()
    self.encoder = Encoder(self.units, self.embedding, self.MAX_SOURCE_VOCAB)
    self.decoder = Decoder(self.units, self.embedding, self.MAX_TARGET_VOCAB)


  def overflow_tokens_to_rel_input_pos(self, inputs):
    out, inp = inputs
    return tf.where(out >= self.MAX_TARGET_VOCAB,
                    self.MAX_TARGET_VOCAB + tf.argmax(inp == tf.expand_dims(out, axis=1), axis=1),
                    out)


  def _preprocess(self, input_text, output_text):
    input_tokens = self.input_preprocessor(input_text)
    target_tokens = self.output_preprocessor(output_text)
    target_tokens_on_input_vocab = self.input_preprocessor(output_text)
    target_tokens_mapped = tf.where(target_tokens == 1, target_tokens_on_input_vocab, target_tokens)
    target_tokens_mapped = tf.vectorized_map(self.overflow_tokens_to_rel_input_pos, (target_tokens_mapped, input_tokens))
    input_mask = (input_tokens != 0)
    target_mask = (target_tokens != 0)

    return input_tokens, input_mask, target_tokens, target_tokens_mapped, target_mask


  @tf.function()
  def distributed_train_step(self, dataset_inputs):
    per_replica_losses = strategy.run(self.train_step, args=(dataset_inputs,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                          axis=None)

  def train_step(self, inputs):
    if self.use_tf_function:
      return self._tf_train_step(inputs)
    else:
      return self._train_step(inputs)

  @tf.function
  def _tf_train_step(self, inputs):
    return self._train_step(inputs)


  def _one_step(self, input_token, target_token, enc_output, input_mask, dec_state, coverage_vector, training=True):
    decoder_input = DecoderInput(tf.cast(input_token, tf.float32), enc_output, input_mask)
    dec_result, dec_state, coverage_vector = self.decoder(decoder_input, coverage_vector, state=dec_state)

    y_true = target_token #(batch, 1)
    y_pred = dec_result.logits #(batch, t_step, outvocab + encoder_output_steps)
    step_loss = self.loss['MaskedLoss'](y_true, y_pred) + self.LAMBDA * self.loss['CoverageLoss'](coverage_vector, dec_result.attention_weights)

    return dec_result, dec_state, coverage_vector, step_loss

  def _train_step(self, inputs):
    input_text, target_text = inputs
    input_tokens, input_mask, target_tokens, target_tokens_mapped, target_mask = self._preprocess(input_text, target_text)
    coverage_vector = tf.zeros(shape=(self.BATCH_SIZE, 1, tf.shape(input_tokens)[1]))
    max_target_sen_length = tf.shape(target_tokens)[1]
    with tf.GradientTape() as tape:
      enc_output, enc_state = self.encoder(input_tokens)
      dec_state = enc_state
      loss = tf.constant(0.0)

      for t_step_dec in tf.range(max_target_sen_length-1):
        input_token = tf.slice(target_tokens, [0, t_step_dec], [-1, 1])
        target_token = tf.slice(target_tokens, [0, t_step_dec+1], [-1, 1])
        dec_result, dec_state, coverage_vector, step_loss = self._one_step(input_token, target_token, enc_output, input_mask, dec_state, coverage_vector, True)
        loss = loss + step_loss

      avg_loss = loss/tf.reduce_sum(tf.cast(target_mask, tf.float32))

    variables = self.trainable_variables
    gradients = tape.gradient(avg_loss, variables)
    self.optimizer.apply_gradients(zip(gradients, variables))
    return avg_loss
with strategy.scope():
  tboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logs)
  text_summarizer = AttentionEncoderDecoder(UNITS, EMBEDDING, input_preprocessing, output_preprocessing,
                                            GLOBAL_BATCH_SIZE_TPU, True, 1)
  text_summarizer.compile(optimizer="Adam",
                          loss={"MaskedLoss": MaskedLoss(), "CoverageLoss": CoverageLoss()},
                          steps_per_execution=50)

  for epoch in range(EPOCHS):
    # TRAIN LOOP
    total_loss = 0.0
    num_batches = 0
    for x in val_dist_dataset:
      total_loss += text_summarizer.distributed_train_step(x)
      num_batches += 1

    if num_batches == 0:
      num_batches = 1
    train_loss = total_loss / num_batches
    print(train_loss)

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-102-82fbe4fc7718> in <module>()
      9     num_batches = 0
     10     for x in val_dist_dataset:
---> 11       total_loss += text_summarizer.distributed_train_step(x)
     12       num_batches += 1
     13 

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   7184 def raise_from_not_ok_status(e, name):
   7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   7187 
   7188 

InvalidArgumentError: 4 root error(s) found. (0) INVALID_ARGUMENT: {{function_node __inference_distributed_train_step_284140}} XLA compilation requires a fixed tensor list size. Set the max number of elements. This could also happen if you're using a TensorArray in a while loop that does not have its maximum_iteration set, you can fix this by setting maximum_iteration to a suitable value.

 [[{{node gradient_tape/replica_6/while/replica_6/while/decoder_2/gru_5/PartitionedCall_11/accumulator}}]]
 [[gradient_tape/replica_6/while/replica_6/while/decoder_8/gru_17/PartitionedCall_11/accumulator]]
 [[replica_3_while_input_16/_387]]
 [[var_shape_replica_5_while_input_17/_17932/_1287]]

(1) INVALID_ARGUMENT: {{function_node __inference_distributed_train_step_284140}} XLA compilation requires a fixed tensor list size. Set the max number of elements. This could also happen if you're using a TensorArray in a while loop that does not have its maximum_iteration set, you can fix this by setting maximum_iteration to a suitable value.

 [[{{node gradient_tape/replica_6/while/replica_6/while/decoder_2/gru_5/PartitionedCall_11/accumulator}}]]
 [[gradient_tape/replica_6/while/replica_6/while/decoder_8/gru_17/PartitionedCall_11/accumulator]]
 [[replica_3_while_input_16/_387]]

(2) INVALID_ARGUMENT: {{function_node __inference_distributed_train_step_284140}} XLA compilation requires a fixed tensor list size. Set the max number of elements. This could also happen if you're using a TensorArray in a while loop that does not have its maximum_iteration set, you can fix this by setting maximum_iteration to a suitable value.

 [[{{node gradient_tape/replica_6/while/replica_6/while/decoder_2/gru_5/PartitionedCall_11/accumulator}}]]
 [[gradient_tape/replica_6/while/replica_6/while/decoder_8/gru_17/PartitionedCall_11/accumulator]]

(3) CANCELLED: {{function_node __inference_distributed_train_step_284140}} Function was cancelled before it was started 0 successful operations. 0 derived errors ignored. [Op:AddV2]
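
Not a confirmed fix, but the message itself points at the decoder loop: the for t_step_dec in tf.range(max_target_sen_length-1) loop inside _train_step is traced as a tf.while_loop whose trip count depends on the batch, so the gradient accumulators (TensorArrays) have no static size and XLA rejects them. The hint in the error ("set maximum_iteration") can be expressed with AutoGraph's tf.autograph.experimental.set_loop_options. The sketch below only demonstrates that pattern on a toy function, not on the model above; MAX_STEPS is an assumed constant standing in for the longest padded target length, and jit_compile=True is used to force XLA roughly the way a TPU replica would.

import tensorflow as tf

# Hedged sketch, not verified against the model above: give the traced
# while-loop a static upper bound so XLA can size the gradient TensorArrays.
# MAX_STEPS is an assumed constant standing in for the padded target length.
MAX_STEPS = 64

@tf.function(jit_compile=True)  # jit_compile forces XLA compilation
def bounded_loop(x, n):
  with tf.GradientTape() as tape:
    tape.watch(x)
    total = tf.constant(0.0)
    for i in tf.range(n):
      # Must be the first statement in the loop body: bounds the loop so the
      # backward-pass accumulators get a fixed maximum size.
      tf.autograph.experimental.set_loop_options(maximum_iterations=MAX_STEPS)
      total += tf.reduce_sum(x) * tf.cast(i + 1, tf.float32)
  return total, tape.gradient(total, x)

print(bounded_loop(tf.ones([4]), tf.constant(10)))

If this is the missing piece, the set_loop_options call would go as the first statement inside the for t_step_dec in tf.range(...) loop in _train_step, with the bound derived from however the target text is padded. Batching with drop_remainder=True is also commonly suggested so the batch dimension stays static on TPU, but I have not confirmed either change against this exact model.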



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
