Why are prediction results different from training results with ResNet50 in Keras?
I'm a newbie in machine learning. I want to build a Keras model for image recognition. Here is my current code using the ResNet50 model:
```python
import numpy as np
import cv2
import PIL.Image
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.optimizers import Adam

# data_dir (dataset root directory) and rock (a list of image paths from one
# class) are defined earlier in the original post and are not shown here.
PIL.Image.open(str(rock[0]))

img_height, img_width = 224, 224
batch_size = 32

trains_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=42,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=42,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = trains_ds.class_names

resnet_model = Sequential()

# Frozen ImageNet ResNet50 backbone. pooling='avg' already outputs a flat
# 2048-d vector; classes is only used when include_top=True.
pretrained_model = tf.keras.applications.ResNet50(include_top=False,
                                                  input_shape=(224, 224, 3),
                                                  pooling='avg',
                                                  classes=20,
                                                  weights='imagenet')
for layer in pretrained_model.layers:
    layer.trainable = False

resnet_model.add(pretrained_model)
resnet_model.add(Flatten())
resnet_model.add(Dense(512, activation='relu'))
resnet_model.add(Dense(20, activation='softmax'))
resnet_model.summary()

resnet_model.compile(optimizer=Adam(learning_rate=0.001),
                     loss='categorical_crossentropy',
                     metrics=['accuracy'])

epochs = 5
history = resnet_model.fit(
    trains_ds,
    validation_data=val_ds,
    epochs=epochs)

# Predict on a single image loaded with OpenCV
image = cv2.imread(str(rock[0]))
image_resized = cv2.resize(image, (img_width, img_height))  # cv2.resize takes (width, height)
print(img_height)
print(img_width)
image = np.expand_dims(image_resized, axis=0)  # add batch dimension -> (1, 224, 224, 3)
print(image.shape)
pred = resnet_model.predict(image)
print(pred)
output_class = class_names[np.argmax(pred)]
print("The predicted class is", output_class)
```
There are 20 classes. After training, the results look great: both the training accuracy and the validation accuracy are almost 100%. But when I pick some files from `train_ds` and some from `val_ds` and run `predict` on them one at a time, the predictions are bad. I don't know why. Is this model overfitting, or is something else going on?
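One commonly cited cause of this symptom, offered here as a hypothesis rather than a confirmed diagnosis: `image_dataset_from_directory` decodes files in RGB channel order and yields float32 batches, while `cv2.imread` returns pixels in BGR order. Validation accuracy is computed through the RGB dataset pipeline, so it can look near-perfect even when single-image predictions made through OpenCV see channel-swapped input. That points away from overfitting, since a model that overfits would normally show low validation accuracy. The sketch below uses only NumPy and a synthetic image. The helper name `to_model_batch` is hypothetical. It reverses the channel axis (what `cv2.cvtColor(image, cv2.COLOR_BGR2RGB)` does for a 3-channel image) and builds the `(1, H, W, 3)` float32 batch layout the training dataset yields.

```python
import numpy as np


def to_model_batch(bgr_image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn one HxWx3 BGR image (cv2.imread layout) into a
    (1, H, W, 3) float32 RGB batch, the layout image_dataset_from_directory yields.
    """
    if bgr_image.ndim != 3 or bgr_image.shape[-1] != 3:
        raise ValueError("expected an HxWx3 image")
    rgb = bgr_image[..., ::-1]          # reverse channel axis: BGR -> RGB
    rgb = rgb.astype(np.float32)        # training batches are float32 in [0, 255]
    return np.expand_dims(rgb, axis=0)  # add the batch dimension


if __name__ == "__main__":
    # Synthetic 224x224 image that is pure red in RGB terms.
    rgb_red = np.zeros((224, 224, 3), dtype=np.uint8)
    rgb_red[..., 0] = 255
    # OpenCV would store the same pixels in BGR order (red in the last channel).
    bgr_red = rgb_red[..., ::-1]

    batch = to_model_batch(bgr_red)
    print(batch.shape, batch.dtype)            # (1, 224, 224, 3) float32
    print(np.array_equal(batch[0], rgb_red))   # True: channel order restored
    print(np.array_equal(bgr_red, rgb_red))    # False: raw cv2 order differs
```

Applied to the question's prediction code, this would correspond to converting `image_resized` with `cv2.cvtColor(image_resized, cv2.COLOR_BGR2RGB)` before `np.expand_dims`, so that single-image inputs match what the model saw during training.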
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
