U-Net Binary Segmentation on RGB Images

I am trying to segment an RGB image using U-Net weights trained with the segmentation_models library. However, I keep getting this warning: "WARNING:tensorflow:Model was constructed with shape (None, 256, 256, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='data'), name='data', description="created by layer 'data'"), but it was called on an input with incompatible shape (None, 256, 256, 1)."

Here's my code. Any help would be appreciated.

from simple_unet_model import simple_unet_model 
import cv2
import numpy as np
from keras.utils import normalize
from matplotlib import pyplot as plt

# Edge length, in pixels, of the square tiles the large image is cut into
# before each tile is passed to the model.
patch_size = 256


def get_model(img_height=256, img_width=256, img_channels=1):
    """Build the U-Net used for tiled prediction.

    The defaults reproduce the original 256x256 single-channel model, so
    ``get_model()`` behaves exactly as before.

    Args:
        img_height: spatial input height in pixels.
        img_width: spatial input width in pixels.
        img_channels: input channel count (1 for grayscale, 3 for RGB).
            It must match both the weights later loaded into the model and
            the images fed to it. A mismatch makes Keras report an
            "incompatible shape" between construction and call.

    Returns:
        Whatever ``simple_unet_model`` returns for these dimensions.
    """
    # Passed positionally, exactly as the original call did.
    return simple_unet_model(img_height, img_width, img_channels)



def prediction(model, image, patch_size):
    """Segment a large image by tiling it into non-overlapping patches.

    The image is walked in ``patch_size`` steps. Each tile is normalized
    with ``normalize(..., axis=1)`` and passed through ``model.predict``.
    The first output channel is thresholded at 0.5 into a 0/1 mask, which
    is written back at the tile's position.

    Args:
        model: object with a ``predict`` method accepting a batch shaped
            ``(1, patch_size, patch_size, C)`` and returning an array indexed
            as ``[batch, row, col, channel]``. Its channel count C must match
            the image's (1 for a 2-D image).
        image: 2-D ``(H, W)`` grayscale array or 3-D ``(H, W, C)`` array.
        patch_size: tile edge length in pixels; also the model's spatial
            input size.

    Returns:
        Float ndarray of shape ``(H, W)`` holding the stitched 0/1 mask.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    image = np.asarray(image)
    if image.ndim not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D image, got shape {image.shape}")
    height, width = image.shape[:2]
    segm_img = np.zeros((height, width))  # filled tile by tile below
    patch_num = 1
    # Step by patch_size (not a hard-coded 256) so tiles neither overlap
    # nor leave gaps for other patch sizes.
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            single_patch = image[i:i + patch_size, j:j + patch_size]
            patch_h, patch_w = single_patch.shape[:2]
            patch_norm = normalize(np.array(single_patch), axis=1)
            if patch_norm.ndim == 2:
                # Grayscale tile: add the trailing channel axis the model
                # expects. Multi-channel (e.g. RGB) tiles already have one.
                patch_norm = np.expand_dims(patch_norm, 2)
            # Tiles on the right/bottom edges can be smaller than patch_size.
            # Zero-pad them so a fixed-size model accepts them, then crop the
            # mask back instead of resizing it.
            pad_width = ((0, patch_size - patch_h), (0, patch_size - patch_w), (0, 0))
            patch_input = np.expand_dims(np.pad(patch_norm, pad_width), 0)
            mask = (model.predict(patch_input)[0, :, :, 0] > 0.5).astype(np.uint8)
            segm_img[i:i + patch_h, j:j + patch_w] += mask[:patch_h, :patch_w]

            print("Finished processing patch number ", patch_num, " at position ", i, j)
            patch_num += 1
    return segm_img

##########
# Load model and predict
model = get_model()
#model.load_weights('mitochondria_gpu_tf1.4.hdf5')
model.load_weights('mitochondria_50_plus_100_epochs.hdf5')

# Large image.
# Read it with as many channels as the model's input expects. Feeding a
# grayscale (1-channel) image to a 3-channel model triggers the
# "constructed with shape (..., 3) but called on ... (..., 1)" warning.
model_channels = model.inputs[0].shape[-1]
read_flag = cv2.IMREAD_COLOR if model_channels == 3 else cv2.IMREAD_GRAYSCALE
large_image = cv2.imread('data/01-1.tif', read_flag)
if large_image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    raise FileNotFoundError("could not read image 'data/01-1.tif'")
segmented_image = prediction(model, large_image, patch_size)
plt.hist(segmented_image.flatten())  # histogram of the segmented mask values

plt.imsave('data/results/segm.jpg', segmented_image, cmap='gray')

# OpenCV loads color images in BGR order; matplotlib expects RGB for display.
if large_image.ndim == 3:
    display_image = cv2.cvtColor(large_image, cv2.COLOR_BGR2RGB)
else:
    display_image = large_image

plt.figure(figsize=(8, 8))
plt.subplot(221)
plt.title('Large Image')
plt.imshow(display_image, cmap='gray')
plt.subplot(222)
plt.title('Prediction of large Image')
plt.imshow(segmented_image, cmap='gray')
plt.show()


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source