TensorFlow 2.x semantic segmentation: convert mask from logits to image

I am using DataLoader to read in images and masks for semantic segmentation with two classes (background and line). I define the palette like this:

palette = {
    (0, 0, 0): 0,  # background (black)
    (255, 255, 255): 1,  # line (white)
}

Then I read the images and masks using:

dataset = DataLoader(image_paths=image_paths,
                     mask_paths=mask_paths,
                     image_size=(IMG_SIZE, IMG_SIZE),
                     crop_percent=0.1,
                     channels=(3, 1),
                     augment=False,
                     compose=False,
                     one_hot_encoding=True,
                     palette=palette,
                     seed=SEED)

After this, the dataset looks like this:

<BatchDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.uint8, name=None))>

How can I convert the mask from logits to an image? If I try to take one sample:

for images, masks in dataset.take(1):
    sample_image, sample_mask = images[0], masks[0]

It gives the error

InvalidArgumentError: ValueError: Tensor conversion requested dtype uint8 for Tensor with dtype float32: <tf.Tensor: shape=(128, 128, 2), dtype=float32, numpy=
array([[[1., 0.],
    [1., 0.],
    [1., 0.],
    ...,
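
The tensor printed in the error has shape (128, 128, 2) and holds values like [1., 0.], so it looks like a float32 one-hot mask rather than raw logits, while uint8 was requested. As a rough sketch of the conversion being asked about (not a fix for whatever raises the error inside the pipeline), a one-hot or per-class score mask can be collapsed to class indices with an argmax over the last axis and then cast to uint8. The function name and the tiny example mask below are hypothetical, and plain NumPy is used so the snippet runs on its own.

```python
import numpy as np


def one_hot_to_class_indices(one_hot_mask):
    """Collapse a (H, W, num_classes) one-hot or score mask to (H, W) uint8 class ids."""
    mask = np.asarray(one_hot_mask)
    # The position of the largest value along the last axis is the class index.
    return np.argmax(mask, axis=-1).astype(np.uint8)


# Hypothetical 2x2 mask with two classes (0 = background, 1 = line).
example = np.array([[[1., 0.], [0., 1.]],
                    [[1., 0.], [1., 0.]]], dtype=np.float32)
class_ids = one_hot_to_class_indices(example)
print(class_ids.shape, class_ids.dtype)
```

For an eager TensorFlow tensor, calling .numpy() on it first should give an array this function accepts.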

How can I display images from this dataset using a function like this:

import matplotlib.pyplot as plt
import tensorflow as tf


def display(display_list):
  plt.figure(figsize=(5, 5))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  # One subplot per entry: input image, true mask, predicted mask
  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
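
One possible way to get a mask that a function like display can show is to map each pixel's class back to its palette colour. The sketch below assumes the mask is one-hot with the class axis last, as the error message suggests, and reuses the palette from the question. The helper name mask_to_rgb and the demo mask are made up for illustration, and NumPy is used instead of TensorFlow ops to keep the example self-contained.

```python
import numpy as np

# Same palette as in the question: RGB colour -> class index.
palette = {
    (0, 0, 0): 0,        # background (black)
    (255, 255, 255): 1,  # line (white)
}


def mask_to_rgb(one_hot_mask, palette):
    """Turn a (H, W, num_classes) one-hot mask into a (H, W, 3) uint8 RGB image."""
    class_ids = np.argmax(np.asarray(one_hot_mask), axis=-1)
    # Lookup table: row i holds the RGB colour assigned to class i.
    colours = np.zeros((len(palette), 3), dtype=np.uint8)
    for colour, class_id in palette.items():
        colours[class_id] = colour
    # Indexing the table with the class-id map yields one colour per pixel.
    return colours[class_ids]


# Hypothetical 1x2 mask: first pixel background, second pixel line.
demo_mask = np.array([[[1., 0.], [0., 1.]]], dtype=np.float32)
demo_rgb = mask_to_rgb(demo_mask, palette)
print(demo_rgb.shape, demo_rgb.dtype)
```

With a real sample, something like display([sample_image, mask_to_rgb(sample_mask.numpy(), palette)]) should then work, since array_to_img accepts (H, W, 3) arrays. This assumes sample_mask is an eager tensor that could be read from the dataset in the first place.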


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
