Multiple Keras Sequence.on_epoch_end() calls

Data for a Keras generator is randomly fetched from a DB once during initialization and again at the end of each epoch via on_epoch_end(). During training with multiple workers I've noticed that on_epoch_end() is called multiple times at the end of each epoch instead of the expected single call.

Using traceback I've found that the extra calls introduced by multiple workers all originate from worker threads in threading.py:

[custom_generator.py] on_epoch_end called
 File "/opt/conda/lib/python3.9/threading.py", line 930, in _bootstrap
   self._bootstrap_inner()
 File "/opt/conda/lib/python3.9/threading.py", line 973, in _bootstrap_inner
   self.run()
 File "/opt/conda/lib/python3.9/threading.py", line 910, in run
   self._target(*self._args, **self._kwargs)
 File "/opt/conda/lib/python3.9/site-packages/keras/utils/data_utils.py", line 761, in _run
   self.sequence.on_epoch_end()
 File "/home/user/custom_generator.py", line 110, in on_epoch_end
   traceback.print_stack()

Generator

import traceback

import tensorflow as tf

class CustomGenerator(tf.keras.utils.Sequence):
    def __init__(self, sampler, processor, batch_size, steps_per_epoch):
        # Input size: (512, 512, 3) RGB images
        self.sampler = sampler
        self.processor = processor
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch
        self.samples = self.sampler.fetch(batch_size * steps_per_epoch)

    def __len__(self):
        return self.steps_per_epoch

    def __getitem__(self, index):
        return self.processor(
            self.samples[index * self.batch_size:(index + 1) * self.batch_size]
        )

    def on_epoch_end(self):
        self.samples = self.sampler.fetch(self.batch_size * self.steps_per_epoch)
        traceback.print_stack()

gen = CustomGenerator(Sampler(), Processor(), batch_size=1, steps_per_epoch=1500)

I'm using TensorFlow 2.8.0; the same issue occurred with TF 2.3.0 as well. Is there any way to reduce the number of calls to one without reducing the number of workers? Is this behaviour documented anywhere? Strangely, I am unable to reproduce the issue with dummy data.
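One possible mitigation (my own sketch, not from the original question): make on_epoch_end() idempotent by counting served batches and resampling only after a full epoch's worth, so duplicate invocations become no-ops. The DB sampler is replaced here by a random stand-in, and the counter is not thread-safe; that is tolerable as a rough guard with use_multiprocessing=False, where all workers share one Sequence instance:

```python
import numpy as np
import tensorflow as tf

class DebouncedGenerator(tf.keras.utils.Sequence):
    """Resample only once per completed epoch, ignoring duplicate
    on_epoch_end() invocations from worker threads."""

    def __init__(self, batch_size, steps_per_epoch):
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch
        self.batches_served = 0
        self.resample_count = 0
        self._resample()

    def _resample(self):
        # Stand-in for the real sampler.fetch() DB call
        self.samples = np.random.random(
            (self.batch_size * self.steps_per_epoch, 512)
        )
        self.resample_count += 1

    def __len__(self):
        return self.steps_per_epoch

    def __getitem__(self, index):
        self.batches_served += 1
        return self.samples[index * self.batch_size:(index + 1) * self.batch_size]

    def on_epoch_end(self):
        # Only resample if a full epoch of batches has actually been served;
        # spurious extra calls arrive before the counter fills up again.
        if self.batches_served >= self.steps_per_epoch:
            self.batches_served = 0
            self._resample()
```

This does not explain the extra calls, but it keeps redundant invocations from hammering the DB.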

EDIT I've managed to reproduce the problem. It happens only when both a training and a validation generator are passed to model.fit(). The validation generator receives the expected number of on_epoch_end() calls. The training generator, however, receives at least twice the expected number.

Reproducible code

import numpy as np
import tensorflow as tf

class Generator(tf.keras.utils.Sequence):

    def __init__(self, steps_per_epoch, batch_size):
        self.oee_counter = 0
        self.steps_per_epoch = steps_per_epoch
        self.batch_size = batch_size

    def __getitem__(self, idx):
        self.X = np.random.random((self.batch_size, 512))
        self.y = np.random.randint(low=0, high=2, size=self.batch_size)
        return self.X, self.y

    def on_epoch_end(self):
        self.oee_counter += 1
        print('Called OEE')

    def __len__(self):
        return self.steps_per_epoch
    
gen1 = Generator(1500,1)
gen2 = Generator(2000,1)

inp = tf.keras.layers.Input((512,))
out = tf.keras.layers.Dense(1,activation='sigmoid')(inp)
m = tf.keras.Model(inp, out)
m.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.BinaryCrossentropy())

m.fit(
    x=gen1,
    validation_data=gen2,
    validation_freq=1,
    epochs=3,
    workers=10,
    max_queue_size=50,
    use_multiprocessing=False,
)

print(gen1.oee_counter) # 6 calls (expected 3)
print(gen2.oee_counter) # 3 calls (correct)
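A workaround I'd consider (my own sketch, not from the original post): skip Sequence.on_epoch_end() entirely and resample from a tf.keras.callbacks.Callback, whose on_epoch_end hook Keras invokes exactly once per epoch in the main thread. The resample() method below is hypothetical; wire it to whatever refetches your generator's samples:

```python
import tensorflow as tf

class ResampleCallback(tf.keras.callbacks.Callback):
    """Resample the training generator once per epoch from the main
    thread, sidestepping duplicated Sequence.on_epoch_end() calls."""

    def __init__(self, generator):
        super().__init__()
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        # `resample` is a hypothetical method on the generator that
        # refetches a fresh epoch's worth of samples from the DB.
        self.generator.resample()
```

Usage: pass callbacks=[ResampleCallback(gen1)] to model.fit() and leave the Sequence's on_epoch_end() empty. One caveat (an assumption on my part): with workers > 0 the enqueuer may already have prefetched a few batches from the old sample set when the swap happens.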


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
