U-Net Binary Segmentation on RGB Images
I am trying to segment an RGB image using U-Net weights trained with the segmentation_models library. However, I keep getting this warning: "WARNING:tensorflow:Model was constructed with shape (None, 256, 256, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name='data'), name='data', description="created by layer 'data'"), but it was called on an input with incompatible shape (None, 256, 256, 1)."
Here is my code; any help would be appreciated.
from simple_unet_model import simple_unet_model
import cv2
import numpy as np
from keras.utils import normalize
from matplotlib import pyplot as plt
patch_size = 256  # side length, in pixels, of the square tiles fed to the network


def get_model(n_channels=1):
    """Build the U-Net used for tiled prediction.

    The network is built for ``patch_size`` x ``patch_size`` tiles with
    ``n_channels`` channels per pixel. ``n_channels`` must match both the
    weights that will be loaded and the images passed to ``prediction``.
    A mismatch (for example, RGB-trained weights fed grayscale tiles)
    produces Keras' "Model was constructed with shape ... but it was called
    on an input with incompatible shape" warning.

    Args:
        n_channels: Channels per input pixel. Use 1 (the default, as before)
            for grayscale or 3 for RGB.

    Returns:
        Whatever ``simple_unet_model`` returns (a Keras model, as used with
        ``load_weights``/``predict`` elsewhere in this script).
    """
    # NOTE(review): presumably simple_unet_model takes (height, width,
    # channels) -- confirm against simple_unet_model.py.
    return simple_unet_model(patch_size, patch_size, n_channels)
def prediction(model, image, patch_size):
    """Segment a large image by tiling it into non-overlapping square patches.

    Each tile is L2-normalised with keras' ``normalize`` along axis 1, run
    through ``model.predict``, and the first output channel is thresholded at
    0.5. Tiles on the right/bottom edges that are smaller than ``patch_size``
    are zero-padded to the full size before prediction, and the result is
    cropped back. This way the model always receives the spatial shape it was
    built with.

    Args:
        model: Model whose ``predict`` output is indexed here as
            ``[batch, row, col, channel]``.
        image: 2-D (grayscale) array, or 3-D array whose last axis is the
            channel axis. Its channel count must match the model's input.
        patch_size: Side length of the square tiles, in pixels.

    Returns:
        Float array of shape ``image.shape[:2]`` holding 0/1 labels.
    """
    segm_img = np.zeros(image.shape[:2])  # filled tile by tile with 0/1 labels
    patch_num = 1
    # Step by patch_size (previously hard-coded to 256) so tiles never overlap or leave gaps.
    for i in range(0, image.shape[0], patch_size):
        for j in range(0, image.shape[1], patch_size):
            single_patch = np.asarray(image[i:i + patch_size, j:j + patch_size])
            rows, cols = single_patch.shape[:2]  # smaller than patch_size at the edges

            single_patch_norm = np.asarray(normalize(single_patch, axis=1))
            if single_patch_norm.ndim == 2:
                # Grayscale tile: add the trailing channel axis the model expects.
                single_patch_norm = np.expand_dims(single_patch_norm, 2)

            # Zero-pad edge tiles up to the full patch size so the model input
            # shape is always (1, patch_size, patch_size, channels).
            pad_width = ((0, patch_size - rows), (0, patch_size - cols), (0, 0))
            single_patch_input = np.expand_dims(np.pad(single_patch_norm, pad_width), 0)

            single_patch_prediction = (
                model.predict(single_patch_input)[0, :, :, 0] > 0.5
            ).astype(np.uint8)
            if single_patch_prediction.shape != (patch_size, patch_size):
                # Model output resolution differs from its input: rescale to the tile grid.
                single_patch_prediction = cv2.resize(
                    single_patch_prediction, (patch_size, patch_size)
                )

            # Crop away the padded region. Tiles do not overlap, so each pixel is written once.
            segm_img[i:i + rows, j:j + cols] = single_patch_prediction[:rows, :cols]
            print("Finished processing patch number ", patch_num, " at position ", i, j)
            patch_num += 1
    return segm_img
##########
# Load model and predict
model = get_model()
# model.load_weights('mitochondria_gpu_tf1.4.hdf5')
model.load_weights('mitochondria_50_plus_100_epochs.hdf5')

# Large image: read it with the channel count the network was built for.
# Reading a grayscale image into a model that expects 3 channels (or vice
# versa) is what triggers the "incompatible shape (None, 256, 256, 1)" warning.
image_path = 'data/01-1.tif'
n_channels = model.inputs[0].shape[-1]
if n_channels == 1:
    large_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
else:
    large_image = cv2.imread(image_path, cv2.IMREAD_COLOR)
if large_image is None:
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with a clear message rather than later in prediction().
    raise FileNotFoundError(f"Could not read image: {image_path}")
if large_image.ndim == 3:
    # NOTE(review): OpenCV loads colour images in BGR order. This assumes the
    # network was trained on RGB-ordered images -- confirm against the training code.
    large_image = cv2.cvtColor(large_image, cv2.COLOR_BGR2RGB)

segmented_image = prediction(model, large_image, patch_size)

# Histogram of the predicted 0/1 labels (plot only; no thresholding happens here).
plt.hist(segmented_image.flatten())
plt.imsave('data/results/segm.jpg', segmented_image, cmap='gray')

plt.figure(figsize=(8, 8))
plt.subplot(221)
plt.title('Large Image')
plt.imshow(large_image, cmap='gray')  # cmap is ignored for 3-channel images
plt.subplot(222)
plt.title('Prediction of large Image')
plt.imshow(segmented_image, cmap='gray')
plt.show()
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
