Keras generator hangs with multiprocessing and generates the same index multiple times without multiprocessing

I defined a custom generator extending keras.utils.Sequence that uses slices from an iterator to generate each batch (code below). When I run predict() with the generator and use_multiprocessing=True, it hangs after the first batch. If I set use_multiprocessing=False, the first index (0) is generated twice, resulting in too many batches and an error on the last batch.

Note: I am not using a GPU for predict(); it is running on a server with 80 cores and 800 GB of RAM, on Python 3.7 with Keras v2.6.0.

Below is a minimal reproduction:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # CPU only; must be set before TensorFlow is imported

import itertools

import numpy as np
import tensorflow as tf
from keras.models import load_model
from tensorflow.keras.utils import Sequence


class DataGenerator(Sequence):
    def __init__(self, L, batch_size=4**7):
        self.L = L
        self.batch_size = batch_size
        # Shared iterator over all length-L tuples from the alphabet "abcd"
        self.it = itertools.product("abcd", repeat=L)

    def __len__(self):
        return int(np.floor(4**self.L / self.batch_size))

    def __getitem__(self, idx):
        # idx is ignored here; each call consumes the next slice of the shared iterator
        print("Index: ", idx)
        lst = np.array(list(itertools.islice(self.it, self.batch_size)))
        return lst


config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=60,
                                  inter_op_parallelism_threads=2,
                                  allow_soft_placement=True,
                                  device_count={'CPU': 60})

session = tf.compat.v1.Session(config=config)

model = load_model('/path/to/model', compile=False)

gen = DataGenerator(15)

labels = model.predict(gen, workers=60, steps=int(np.floor(4**15 / 4**7)), verbose=1, use_multiprocessing=True)

As is, the code prints "Index: 0" and then hangs. If I set use_multiprocessing=False, index 0 is printed twice before the other indices, and the iterator then runs out of data on the last batch, causing an error; i.e. too many batches are generated because index 0 occurs twice.

Note: I observed the same behaviour when I tried the deprecated predict_generator() instead of predict().

Any ideas how to solve this issue? I would prefer a solution that works with use_multiprocessing=True if possible.
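For context on why index-dependence matters here: with workers, Keras may call __getitem__ concurrently and in any order, so a __getitem__ that consumes a shared iterator instead of using idx is inherently racy. A sketch of one possible workaround (untested with predict(); StatelessGenerator and nth_product are names made up for this sketch) is to compute each batch as a pure function of idx, so every worker process can derive its batch independently by converting the element index to base-4 digits over the alphabet:

```python
import itertools

import numpy as np

ALPHABET = "abcd"


def nth_product(n, L, alphabet=ALPHABET):
    """Return the n-th tuple of itertools.product(alphabet, repeat=L)
    by writing n in base len(alphabet), most-significant digit first."""
    base = len(alphabet)
    digits = []
    for _ in range(L):
        n, r = divmod(n, base)
        digits.append(alphabet[r])
    return tuple(reversed(digits))


class StatelessGenerator:  # would subclass keras.utils.Sequence in practice
    def __init__(self, L, batch_size=4**7):
        self.L = L
        self.batch_size = batch_size

    def __len__(self):
        return 4**self.L // self.batch_size

    def __getitem__(self, idx):
        # Batch is derived purely from idx: no shared iterator, no state,
        # so concurrent or repeated calls always return the same result.
        start = idx * self.batch_size
        return np.array([nth_product(start + i, self.L)
                         for i in range(self.batch_size)])


# quick sanity check against itertools.product itself
full = list(itertools.product(ALPHABET, repeat=3))
gen = StatelessGenerator(3, batch_size=16)
assert [tuple(row) for row in gen[1]] == full[16:32]
```

Because each batch depends only on idx, duplicate or out-of-order index requests (the symptom observed above) would at worst recompute a batch rather than corrupt the stream.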



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
