Keras semantic segmentation, infinite epoch using ImageDataGenerators
I am trying to train a model based on the U-Net architecture. I am using two data generators (one for training, the other for validation). However, whatever values I use for batch_size, steps_per_epoch, etc., an epoch never ends.
The directory structure is:
-- images
    -- train
        -- img
            -- img1.jpg ...
        -- mask
            -- mask1.jpg ...
    -- val
        -- img
            -- img1.jpg ...
        -- mask
            -- mask1.jpg ...
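The image and mask generators are created independently, so each image only meets its mask if both folders list their files in the same order. As far as I can tell, `flow_from_directory` sorts the filenames it finds, so pairing effectively happens by sorted order. A quick sanity check for the layout above might look like the sketch below. `check_pairs` is a hypothetical helper, not part of Keras, and it assumes the `imgN` / `maskN` naming shown in the tree.

```python
from pathlib import Path

IMG_EXTS = {".jpg", ".jpeg", ".png"}

def check_pairs(split_dir):
    """Hypothetical helper: pair <split_dir>/img and <split_dir>/mask files by sorted name."""
    split = Path(split_dir)
    imgs = sorted(p.name for p in (split / "img").iterdir() if p.suffix.lower() in IMG_EXTS)
    masks = sorted(p.name for p in (split / "mask").iterdir() if p.suffix.lower() in IMG_EXTS)
    if len(imgs) != len(masks):
        raise ValueError(f"{split}: {len(imgs)} images but {len(masks)} masks")
    pairs = list(zip(imgs, masks))
    for img, mask in pairs:
        # Compare what follows the prefix, e.g. "7" in img7.jpg and mask7.jpg
        if Path(img).stem[len("img"):] != Path(mask).stem[len("mask"):]:
            raise ValueError(f"{split}: {img} is paired with {mask}")
    return pairs
```

Calling `check_pairs("./images/train")` and `check_pairs("./images/val")` before training should raise if a mask is missing or the names drift apart.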
Creating datasets:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_dir = "./images/train"
val_dir = "./images/val"
seed = 42

# Same augmentation parameters for images and masks; together with the shared
# seed, each image and its mask receive identical random transforms.
train_datagen_params = dict(rotation_range=90,
                            zoom_range=0.2,
                            horizontal_flip=True,
                            width_shift_range=0.2,
                            height_shift_range=0.2,
                            rescale=1./255)

train_datagen_images = ImageDataGenerator(**train_datagen_params)
train_datagen_masks = ImageDataGenerator(**train_datagen_params)

# class_mode=None: yield input batches only (no labels).
# classes=["img"] / classes=["mask"]: read only that subfolder of train_dir.
train_images = train_datagen_images.flow_from_directory(directory=train_dir,
                                                        target_size=(512, 512),
                                                        batch_size=2,
                                                        class_mode=None,
                                                        classes=["img"],
                                                        seed=seed)
train_masks = train_datagen_masks.flow_from_directory(directory=train_dir,
                                                      target_size=(512, 512),
                                                      batch_size=2,
                                                      class_mode=None,
                                                      color_mode="grayscale",
                                                      classes=["mask"],
                                                      seed=seed)
# ... the same for the validation generators

def combine_generator(gen1, gen2):
    # Infinite generator: yields (image_batch, mask_batch) pairs forever
    while True:
        yield (next(gen1), next(gen2))

train_set = combine_generator(train_images, train_masks)
val_set = combine_generator(val_images, val_masks)
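The two training iterators share `train_datagen_params` and `seed`, so they draw the same shuffle order and the same random transforms; that is what keeps each image aligned with its mask. The pure-Python sketch below only illustrates the principle with `random.Random`. It does not reproduce Keras's internal RNG calls, and `shuffled_batches` is a hypothetical stand-in for a shuffling `DirectoryIterator`.

```python
import random

def shuffled_batches(items, batch_size, seed):
    """Hypothetical stand-in for a shuffling iterator that repeats forever."""
    rng = random.Random(seed)
    while True:
        order = list(items)
        rng.shuffle(order)  # a new permutation for every pass over the data
        for start in range(0, len(order), batch_size):
            yield order[start:start + batch_size]

def combine_generator(gen1, gen2):
    while True:
        yield (next(gen1), next(gen2))

images = [f"img{i}.jpg" for i in range(1, 22)]   # 21 training images
masks = [f"mask{i}.jpg" for i in range(1, 22)]   # 21 matching masks
pairs = combine_generator(shuffled_batches(images, 2, seed=42),
                          shuffled_batches(masks, 2, seed=42))
```

Because both lists have the same length and both RNGs start from the same seed, every batch of images lines up with its batch of masks, pass after pass. With different seeds the permutations would almost certainly differ and the pairs would silently misalign.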
It seems to read the data properly:
Found 21 images belonging to 1 classes.
Found 21 images belonging to 1 classes.
Found 7 images belonging to 1 classes.
Found 7 images belonging to 1 classes.
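Since `combine_generator` never signals the end of the data, the number of batches per epoch has to be given explicitly. With the 21 training and 7 validation images reported above and batches of 2, a minimal calculation, assuming you want every image seen once per epoch including the final partial batch, is:

```python
import math

def steps_for(num_samples, batch_size):
    """Batches needed to cover num_samples once; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)

steps_per_epoch = steps_for(21, 2)   # 11
validation_steps = steps_for(7, 2)   # 4
```

The question's `21 // 2` evaluates to 10, so the odd image would simply be served at the start of the next epoch, since the generator keeps running across epoch boundaries.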
Training:
model.fit(train_set,
          validation_data=val_set,
          epochs=10,
          batch_size=2,
          steps_per_epoch=1,  # tried with 21//2, didn't work either
          validation_batch_size=2,
          validation_steps=1,
          callbacks=callbacks)
I suspect the problem lies in the number of steps, but I can't figure out how to make it work properly.
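Because `combine_generator` wraps `while True`, it never raises `StopIteration`; the only thing that can end an epoch is the step count. The toy loop below mimics that by pulling exactly `steps_per_epoch` batches per epoch with `itertools.islice`. It uses plain counters in place of the image iterators, and `run_epochs` is a hypothetical simplification of a training loop, not Keras's actual implementation.

```python
import itertools

def combine_generator(gen1, gen2):
    while True:
        yield (next(gen1), next(gen2))

def run_epochs(data, epochs, steps_per_epoch):
    """Toy loop: each epoch consumes exactly steps_per_epoch batches from data."""
    batches_per_epoch = []
    for _ in range(epochs):
        epoch_batches = list(itertools.islice(data, steps_per_epoch))
        batches_per_epoch.append(len(epoch_batches))
    return batches_per_epoch

# Counters stand in for the image and mask iterators; both are infinite.
data = combine_generator(itertools.count(), itertools.count(100))
```

Here `run_epochs(data, 3, 11)` returns after 33 batches, while calling `list(data)` would never return. That unbounded behavior is what you see when nothing caps the number of steps.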
Solution 1:
I found the solution. The documentation of fit says:
Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).
Removing the batch_size and validation_batch_size parameters from the fit call solved the problem.
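Applied to the call from the question, the fix means dropping `batch_size` and `validation_batch_size` and letting the step counts bound each epoch. The sketch wraps it in a hypothetical `train` helper so the remaining arguments are explicit. `model`, `train_set`, `val_set` and `callbacks` are assumed to be the objects from the question; the sample counts come from the `Found ...` output. Computing the steps from the dataset sizes, instead of keeping the question's `steps_per_epoch=1`, is my own choice here.

```python
import math

def train(model, train_set, val_set, callbacks,
          n_train=21, n_val=7, batch_size=2, epochs=10):
    """Hypothetical helper: fit on infinite generators without batch_size arguments.

    batch_size is used only to compute step counts; the batches themselves
    already have that size because it was set in flow_from_directory.
    """
    return model.fit(train_set,
                     validation_data=val_set,
                     epochs=epochs,
                     steps_per_epoch=math.ceil(n_train / batch_size),
                     validation_steps=math.ceil(n_val / batch_size),
                     callbacks=callbacks)
```

Usage would be `history = train(model, train_set, val_set, callbacks)`.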
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Ethan |