ValueError: logits and labels must have the same shape ((None, 128, 128, 1) vs (None, 1))
I'm trying to build a simple binary image classifier. Initially, my data looked like this:
```
X_train.shape: (1421, 128, 128, 3)
X_test.shape : (356, 128, 128, 3)
y_train.shape: (1421,)
y_test.shape : (356,)
```
I tried to reshape the data with:
```python
X_train = X_train.reshape(-1, img_size, img_size, 3)
X_test = X_test.reshape(-1, img_size, img_size, 3)
y_train = y_train.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)
```
and the shapes became:
```
X_train.shape: (1421, 128, 128, 3)
X_test.shape : (356, 128, 128, 3)
y_train.shape: (1421, 1)
y_test.shape : (356, 1)
```
My model and training code:
```python
model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train[0].shape)),
    Dense(32, activation='relu'),
    Dense(32),
    Dense(1, activation='sigmoid')
])

early_stopping = keras.callbacks.EarlyStopping(
    patience=10,
    min_delta=0.001,
    restore_best_weights=True,
)

history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    batch_size=512,
    epochs=10,
    callbacks=[early_stopping],
    verbose=2
)
```
and got the error:
```
ValueError: logits and labels must have the same shape ((None, 128, 128, 1) vs (None, 1))
```
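To see where `(None, 128, 128, 1)` comes from, here is a stdlib-only sketch that mimics Keras shape inference for the model above. It assumes, as the Keras `Dense` documentation describes, that a `Dense` layer acts only on the last axis of an input with rank greater than 2; the helper `dense_output_shape` is hypothetical and not part of Keras.

```python
def dense_output_shape(input_shape, units):
    """Shape rule for a Dense layer: only the last axis is replaced by `units`.

    Hypothetical helper mirroring how keras.layers.Dense treats inputs of
    rank > 2; it is not part of the Keras API.
    """
    return input_shape[:-1] + (units,)


# The question's model: four Dense layers applied directly to the images.
shape = (None, 128, 128, 3)  # (batch, height, width, channels)
for units in (64, 32, 32, 1):
    shape = dense_output_shape(shape, units)
    print(shape)
# (None, 128, 128, 64)
# (None, 128, 128, 32)
# (None, 128, 128, 32)
# (None, 128, 128, 1)   <- one prediction per pixel, not per image
```

The labels have shape `(None, 1)`, so the loss cannot pair them with these per-pixel predictions, which is the mismatch the error reports.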
Solution 1:[1]
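You need to add a Flatten layer before the Dense layers. A Dense layer only transforms the last axis of its input, so when it receives (128, 128, 3) images the final Dense(1) outputs one value per pixel, (None, 128, 128, 1), instead of one value per image, (None, 1). Flatten first reshapes each image into a single vector: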
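```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten

model = tf.keras.Sequential([
    # Flatten turns each (128, 128, 3) image into a vector of 49152 values
    Flatten(input_shape=(128, 128, 3)),
    Dense(64, activation='relu'),
    Dense(32, activation='relu'),
    Dense(32),
    # A single sigmoid unit gives output shape (None, 1), matching y_train
    Dense(1, activation='sigmoid')
])
```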
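The effect of Flatten can also be sketched numerically in NumPy, assuming a Dense layer computes `activation(inputs @ kernel + bias)` over the last axis. The `dense` helper, the random weights, and the fake image batch below are hypothetical stand-ins, not Keras API or trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(x, units, activation=None):
    """Hypothetical stand-in for a Dense layer with random weights:
    multiplies the last axis of x by a (features, units) kernel."""
    kernel = rng.normal(scale=0.01, size=(x.shape[-1], units))
    bias = np.zeros(units)
    y = x @ kernel + bias
    return activation(y) if activation is not None else y


def relu(z):
    return np.maximum(z, 0)


def sigmoid(z):
    return 1 / (1 + np.exp(-z))


images = rng.random((2, 128, 128, 3))  # a batch of 2 fake RGB images

# Flatten keeps the batch axis and merges the rest: (2, 128*128*3) = (2, 49152)
x = images.reshape(len(images), -1)
x = dense(x, 64, relu)
x = dense(x, 32, relu)
x = dense(x, 32)
probs = dense(x, 1, sigmoid)
print(probs.shape)  # (2, 1): one probability per image, like y_train
```

Without the `reshape` step, the same chain of `dense` calls would return an array of shape `(2, 128, 128, 1)`.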
Alternatively, you can use convolutional layers in your model, as shown below:
```python
from tensorflow.keras import layers, models

model = models.Sequential()
# Convolution + pooling blocks extract spatial features from each image
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
# Flatten before the Dense head so the final output is (None, 1)
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
```
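The output shape of the convolutional model can be checked by hand. This sketch assumes the Keras defaults used above (`padding='valid'` and stride 1 for `Conv2D`, stride equal to the pool size for `MaxPooling2D`); the helper functions are hypothetical and only trace shapes.

```python
def conv2d_shape(shape, filters, kernel=3):
    # 'valid' padding, stride 1: each spatial side shrinks by kernel - 1
    batch, h, w, _ = shape
    return (batch, h - kernel + 1, w - kernel + 1, filters)


def maxpool2d_shape(shape, pool=2):
    # 'valid' padding, stride = pool size: floor division of each side
    batch, h, w, c = shape
    return (batch, h // pool, w // pool, c)


def flatten_shape(shape):
    # Keep the batch axis and multiply the remaining axes together
    batch, *rest = shape
    size = 1
    for d in rest:
        size *= d
    return (batch, size)


shape = (None, 128, 128, 3)
shape = conv2d_shape(shape, 32)   # (None, 126, 126, 32)
shape = maxpool2d_shape(shape)    # (None, 63, 63, 32)
shape = conv2d_shape(shape, 64)   # (None, 61, 61, 64)
shape = maxpool2d_shape(shape)    # (None, 30, 30, 64)
shape = conv2d_shape(shape, 64)   # (None, 28, 28, 64)
shape = flatten_shape(shape)      # (None, 50176)
shape = (shape[0], 64)            # Dense(64) on a 2-D input
shape = (shape[0], 1)             # Dense(1): (None, 1), matching the labels
print(shape)
```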
Then compile the model:
```python
import tensorflow as tf

model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(),
              metrics=['accuracy'])
```
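`BinaryCrossentropy` compares each predicted probability with the label at the same position, which is why predictions and labels must have the same shape. Below is a minimal pure-Python sketch of the formula, assuming the inputs are probabilities (not logits); the function name, the explicit length check, and the clipping `eps` are our own illustrative choices, not the Keras implementation.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and probabilities.

    Hypothetical helper: y_true and y_pred must have the same length,
    loosely mirroring the shape check Keras performs.
    """
    if len(y_true) != len(y_pred):
        raise ValueError("labels and predictions must have the same shape")
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


print(binary_crossentropy([1, 0], [0.9, 0.1]))  # about 0.1054
```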
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | TFer2 |
