Feeding a multi-input .tfrecord file to .fit()

I am trying to train my model on a TFRecord dataset (800 GB). The simplified data pipeline looks like this:

import tensorflow as tf

files = tf.io.matching_files(tfr_dir + '*_' + single_pattern + '_*')
shards = tf.data.Dataset.from_tensor_slices(files)
# Read the tfrecords
dataset = tf.data.TFRecordDataset(filenames=shards, num_parallel_reads=tf.data.AUTOTUNE)
# Parse the tfrecords
dataset = dataset.map(parse_tfr_element, num_parallel_calls=tf.data.AUTOTUNE)
# Apply image augmentation and parameter optimization (tf.py_function with a defined Tout)
dataset = dataset.map(imgaug)
# Filter out erroneous samples (keep only elements whose state flag is False)
dataset = dataset.filter(lambda f1, f2, f3, f4, f5, state: tf.logical_not(state))
# Batch and prefetch the data (no shuffling at the moment)
dataset = dataset.batch(self.config.batch_size, num_parallel_calls=tf.data.AUTOTUNE, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
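To check the filter and batch stages without touching the 800 GB of TFRecords, the reading, parsing and augmentation steps can be replaced by synthetic tensors. This is only a sketch under stated assumptions: the random data, the shrunken image size and the "every third sample is erroneous" state flags are made up for illustration; only the element layout (five float tensors plus one boolean state flag) mirrors the pipeline above.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for parsed + augmented TFRecord elements:
# 10 samples, each a 6-tuple (f1, f2, f3, f4, f5, state) with tiny shapes.
n = 10
rng = np.random.default_rng(0)
f1 = rng.random((n, 4, 8, 8, 3), dtype=np.float32)  # images (shrunk from 256x256)
f2 = rng.random((n, 4, 19, 2), dtype=np.float32)
f3 = rng.random((n, 4, 19, 3), dtype=np.float32)
f4 = rng.random((n, 4, 3, 4), dtype=np.float32)
f5 = rng.random((n, 4, 4, 3), dtype=np.float32)
# Assumption: mark every third sample as erroneous (state == True).
state = np.arange(n) % 3 == 0

dataset = tf.data.Dataset.from_tensor_slices((f1, f2, f3, f4, f5, state))
# Keep only samples whose state flag is False (same rule as the real pipeline).
dataset = dataset.filter(lambda a, b, c, d, e, s: tf.logical_not(s))
# drop_remainder=True gives a static batch dimension in element_spec.
dataset = dataset.batch(2, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

print(sum(1 for _ in dataset))  # 3 (6 valid samples / batch size 2)
```

The four flagged samples are dropped before batching, so every batch that reaches the model contains only valid samples, just as in the full pipeline.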

This gives me the following element spec (batch_size=8):

<RepeatDataset element_spec=
(TensorSpec(shape=(8, 4, 256, 256, 3), dtype=tf.float32, name=None), 
TensorSpec(shape=(8, 4, 19, 2), dtype=tf.float32, name=None), 
TensorSpec(shape=(8, 4, 19, 3), dtype=tf.float32, name=None), 
TensorSpec(shape=(8, 4, 3, 4), dtype=tf.float32, name=None),
TensorSpec(shape=(8, 4, 4, 3), dtype=tf.float32, name=None), 
TensorSpec(shape=(8,), dtype=tf.bool, name=None))>

Components 0, 3 and 4 of each element are the inputs (x), and components 1 and 2 are the ground truth (y), depending on the model.
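Keras' .fit() expects each dataset element to be an (inputs, targets) pair, where inputs and targets may themselves be tuples matching the model's inputs and outputs. A minimal sketch, assuming the component order described above and that the boolean state flag is no longer needed after filtering, is to map the 6-tuple into that nested structure. The function name to_inputs_targets and the zero-filled example batch are hypothetical.

```python
import tensorflow as tf

def to_inputs_targets(f1, f2, f3, f4, f5, state):
    """Regroup a (batched) 6-tuple into ((x1, x2, x3), (y1, y2)) for model.fit().

    Components 0, 3, 4 become the inputs and 1, 2 the targets; the boolean
    state flag is dropped because the filter step has already used it.
    """
    inputs = (f1, f4, f5)
    targets = (f2, f3)
    return inputs, targets

# Hypothetical zero-filled batch with the question's layout (batch of 8, small images).
batch = (
    tf.zeros((8, 4, 8, 8, 3)),
    tf.zeros((8, 4, 19, 2)),
    tf.zeros((8, 4, 19, 3)),
    tf.zeros((8, 4, 3, 4)),
    tf.zeros((8, 4, 4, 3)),
    tf.zeros((8,), dtype=tf.bool),
)
dataset = tf.data.Dataset.from_tensors(batch)
# Dataset.map unpacks the tuple into the six positional arguments.
dataset = dataset.map(to_inputs_targets)
```

The same map can be applied either before or after .batch(); it only regroups the components, so the shapes are unchanged.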

This works well with a custom training loop that iterates over the batches with for step, data in enumerate(dataset) and passes the inputs to the model by simple indexing, e.g. data[0]. However, I can't get it running with .fit(). I tried different approaches to make .fit() iterate over the dataset (next(iter(dataset)), .from_generator()) but have had no luck so far.
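Since the custom loop already selects data[0], data[3] and so on, the same selection can be moved into a map so that .fit() iterates the dataset itself, without next(iter(dataset)) or from_generator(). The sketch below is a toy under explicit assumptions: the random data, the three-input/two-output functional model, its layer sizes and the shrunken shapes are all hypothetical; the only requirement is that the model's inputs and outputs match the component shapes.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 16 samples in the question's 6-tuple layout.
n = 16
rng = np.random.default_rng(0)
data = (
    rng.random((n, 4, 8, 8, 3), dtype=np.float32),  # x: images
    rng.random((n, 4, 19, 2), dtype=np.float32),    # y: target 1
    rng.random((n, 4, 19, 3), dtype=np.float32),    # y: target 2
    rng.random((n, 4, 3, 4), dtype=np.float32),     # x: extra input
    rng.random((n, 4, 4, 3), dtype=np.float32),     # x: extra input
    np.zeros(n, dtype=bool),                        # state flag (all valid here)
)

dataset = (
    tf.data.Dataset.from_tensor_slices(data)
    .filter(lambda f1, f2, f3, f4, f5, s: tf.logical_not(s))
    # Same selection as data[0], data[3], ... in the custom loop.
    .map(lambda f1, f2, f3, f4, f5, s: ((f1, f4, f5), (f2, f3)))
    .batch(8, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Hypothetical functional model with three inputs and two outputs.
img_in = tf.keras.Input(shape=(4, 8, 8, 3))
a_in = tf.keras.Input(shape=(4, 3, 4))
b_in = tf.keras.Input(shape=(4, 4, 3))
features = tf.keras.layers.Concatenate()([
    tf.keras.layers.Flatten()(img_in),
    tf.keras.layers.Flatten()(a_in),
    tf.keras.layers.Flatten()(b_in),
])
hidden = tf.keras.layers.Dense(32, activation="relu")(features)
out1 = tf.keras.layers.Reshape((4, 19, 2))(tf.keras.layers.Dense(4 * 19 * 2)(hidden))
out2 = tf.keras.layers.Reshape((4, 19, 3))(tf.keras.layers.Dense(4 * 19 * 3)(hidden))
model = tf.keras.Model(inputs=[img_in, a_in, b_in], outputs=[out1, out2])
# A single loss string is applied to both outputs.
model.compile(optimizer="adam", loss="mse")

# fit() iterates the dataset directly; each element is ((x1, x2, x3), (y1, y2)).
history = model.fit(dataset, epochs=2, verbose=0)
```

If the RepeatDataset in the printed spec comes from an unbounded .repeat(), the dataset is infinite and .fit() should also be given steps_per_epoch; otherwise the first epoch never ends.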

So how can I feed a multi-input dataset into .fit()? At the moment I am considering not using TFRecords at all, as they have been hard to work with so far.

Thanks for your help and all the best



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
