Not able to use model.fit() in Keras after much work
I have been trying to develop a CNN model to detect and predict flower categories, but I got stuck at model.fit(). The dataset I used is here:
https://drive.google.com/drive/folders/1-QOrDBpVvXWb_zAsaxZnalmvUQRA7yOb?usp=sharing
Here is my code:
```python
# Imports assumed by this snippet (they were not shown in the original post)
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, BatchNormalization
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# Convolutional blocks
model = Sequential()
model.add(Conv2D(32, (3,3), activation='relu', input_shape=(Image_Width, Image_Height, Image_Channels)))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Conv2D(128, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Conv2D(256, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Conv2D(512, (3,3), activation='relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
# 1st Hidden Layer
model.add(Dense(500, activation='sigmoid'))
model.add(BatchNormalization())
model.add(Dropout(0.25))
# 2nd Hidden Layer
model.add(Dense(100, activation='softmax'))
model.add(BatchNormalization())
model.add(Dropout(0.25))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.summary()

# Callbacks
earlystop = EarlyStopping(patience = 10)
learning_rate_reduction = ReduceLROnPlateau(monitor = 'val_acc', patience = 2, verbose = 1, factor = 0.5, min_lr = 0.00001)
callbacks = [earlystop, learning_rate_reduction]

# Map numeric category codes to class-name strings
df_full_flower_dataset['category'] = df_full_flower_dataset['category'].replace({1:'daisy',2:'dandelion',3:'rose',4:'sunflower',5:'tulip'})

# 80/20 train/validation split
train_df, validate_df = train_test_split(df_full_flower_dataset, test_size=0.20, random_state=42)
train_df = train_df.reset_index(drop=True)
validate_df = validate_df.reset_index(drop=True)
total_train = train_df.shape[0]
total_validate = validate_df.shape[0]
batch_size = 15

# Data generators
train_datagen = ImageDataGenerator(rotation_range=15,
                                   rescale=1./255,
                                   shear_range=0.1,
                                   zoom_range=0.2,
                                   horizontal_flip=True,
                                   width_shift_range=0.1,
                                   height_shift_range=0.1
                                   )
train_generator = train_datagen.flow_from_dataframe(train_df,
                                                    x_col='filename', y_col='category',
                                                    target_size=Image_Size,
                                                    class_mode='categorical',
                                                    batch_size=batch_size)
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_dataframe(
    validate_df,
    x_col='filename',
    y_col='category',
    target_size=Image_Size,
    class_mode='categorical',
    batch_size=batch_size
)
test_datagen = ImageDataGenerator(rotation_range=15,
                                  rescale=1./255,
                                  shear_range=0.1,
                                  zoom_range=0.2,
                                  horizontal_flip=True,
                                  width_shift_range=0.1,
                                  height_shift_range=0.1)
test_generator = train_datagen.flow_from_dataframe(train_df,
                                                   x_col='filename', y_col='category',
                                                   target_size=Image_Size,
                                                   class_mode='categorical',
                                                   batch_size=batch_size)

model.fit(
    train_generator,
    steps_per_epoch = 10,
    epochs = 2,
    validation_data = test_generator, verbose = 1,
    validation_steps = 32)
```

```
Epoch 1/2
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-145-f3e92aad38c9> in <module>()
8 steps_per_epoch = 10,
9 epochs = 2,
---> 10 validation_data = test_generator)
1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
57 ctx.ensure_initialized()
58 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 59 inputs, attrs, num_outputs)
60 except core._NotOkStatusException as e:
61 if name is not None:
InvalidArgumentError: required broadcastable shapes
```
This is the error I get. I am new to Python and deep learning, and model.fit() has me stuck. How can I solve this?
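The message "required broadcastable shapes" usually means the model's output tensor and the label tensor have incompatible shapes. In the posted model there is no `Flatten()` between the last pooling block and the first `Dense` layer, and a Keras `Dense` layer only acts on the last axis of its input, so the 4-D feature map stays 4-D. Meanwhile `flow_from_dataframe(..., class_mode='categorical')` yields one-hot labels of shape `(batch_size, 5)`. The NumPy sketch below mimics this with illustrative sizes that are assumptions, not values from the post: a batch of 15 (the question's `batch_size`), a 2x2x512 feature map (what five conv/pool blocks give for an assumed 128x128 input; `Image_Size` is not shown), and random matrices standing in for the `Dense` kernels.

```python
import numpy as np

# Illustrative sizes (assumptions): batch_size=15 as in the question,
# a 2x2x512 feature map (five conv/pool blocks on an assumed 128x128 input),
# and 5 flower classes.
batch, h, w, c, n_classes = 15, 2, 2, 512, 5
rng = np.random.default_rng(0)

features = rng.standard_normal((batch, h, w, c))  # stands in for the last Dropout output
labels = np.eye(n_classes)[rng.integers(0, n_classes, size=batch)]  # one-hot, (15, 5)


def dense(x, units):
    """Mimic a Keras Dense layer: multiply by a kernel along the LAST axis only.
    (The activation is omitted because it does not change the shape.)"""
    kernel = rng.standard_normal((x.shape[-1], units))
    return x @ kernel


# Question's head: Dense(500) then Dense(100) applied straight to the 4-D map.
no_flatten = dense(dense(features, 500), 100)
print("without Flatten:", no_flatten.shape)  # (15, 2, 2, 100)

try:
    np.broadcast_shapes(labels.shape, no_flatten.shape)
except ValueError as err:
    print("labels vs predictions not broadcastable:", err)

# Fixed head: flatten to (batch, 2048), then end with one unit per class.
flat = features.reshape(batch, -1)
with_flatten = dense(dense(flat, 500), n_classes)
print("with Flatten:", with_flatten.shape)  # (15, 5)
```

Without the flatten step the predictions are `(15, 2, 2, 100)` while the labels are `(15, 5)`, and those shapes cannot be broadcast together; flattening and ending with 5 units makes them match.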
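A likely fix is to add a `Flatten()` layer after the convolutional blocks and to end the model with a `Dense` layer that has one unit per class (5 flower categories, not 100), with nothing after the softmax. The sketch below is one way to write that with `tf.keras`, under stated assumptions: a 128x128x3 input (the question's `Image_Width`, `Image_Height` and `Image_Channels` values are not shown), and the illustrative names `IMAGE_SHAPE` and `NUM_CLASSES`, which are not from the post. It also swaps the hidden layers' `sigmoid`/`softmax` activations for `relu`, which is a common choice rather than a requirement.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Assumed input size (hypothetical): the question's Image_* values are not shown.
IMAGE_SHAPE = (128, 128, 3)
NUM_CLASSES = 5  # daisy, dandelion, rose, sunflower, tulip

model = keras.Sequential()
model.add(layers.Input(shape=IMAGE_SHAPE))

# Same five convolutional blocks as in the question.
for filters in (32, 64, 128, 256, 512):
    model.add(layers.Conv2D(filters, (3, 3), activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(0.25))

# Collapse the 4-D feature map into one vector per image.
model.add(layers.Flatten())

# Two hidden layers, mirroring the question's 500 and 100 units.
for units in (500, 100):
    model.add(layers.Dense(units, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dropout(0.25))

# Output layer: one unit per class, softmax last, nothing after it.
model.add(layers.Dense(NUM_CLASSES, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='rmsprop',
              metrics=['accuracy'])

# Sanity check: a dummy batch of 2 images should give one row of
# class probabilities per image.
dummy = np.zeros((2, *IMAGE_SHAPE), dtype='float32')
print(model(dummy).shape)  # (2, 5)
```

Two other details in the posted code are worth checking, although they are probably not the cause of this particular error. First, `test_generator` is built with `train_datagen` from `train_df`, so validating on it measures augmented training images; the `validation_generator` built from `validate_df` looks like the intended `validation_data`. Second, in TensorFlow 2.x the validation accuracy is logged as `'val_accuracy'`, so `ReduceLROnPlateau(monitor='val_acc', ...)` will not find the metric it is told to watch.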
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
