tfrecords: encoding images results in distorted images

I am trying to create a TFRecord file from a dataset built with image_dataset_from_directory, but when I visualize the images to check that the encoding was correct, they turn out distorted.

How I created the tfrecord:

Step 1: create dataset using image_dataset_from_directory

data_dir = 'path to JPG dataset'

load_split = partial(
    tf.keras.preprocessing.image_dataset_from_directory,
    data_dir,
    validation_split=0.2,
    shuffle=True,
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=1,
)

ds_train = load_split(subset='training')
ds_valid = load_split(subset='validation')

Step 2: encoding functions

def process_image(image, label):
    image = tf.image.convert_image_dtype(image, dtype=tf.uint8)
    image = tf.io.encode_jpeg(image)
    
    label = tf.one_hot(label, NUM_CLASSES)
    
    return image, label

def make_example(encoded_image, label):
    image_feature = Feature(
        bytes_list=BytesList(value=[
            encoded_image,
        ]),
    )
    label_feature = Feature(
        float_list=FloatList(value=label)
    )

    features = Features(feature={
        'image': image_feature,
        'label': label_feature,
    })
    
    example = Example(features=features)
    
    return example.SerializeToString()

Step 3: encoding and creating tfrecord

ds_train_encoded = (
    ds_train
    .unbatch()
    .map(process_image)
)

ds_valid_encoded = (
    ds_valid
    .unbatch()
    .map(process_image)
)

ds_train_encoded_iter = (
    ds_train_encoded
    .as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='train.tfrecord') as f: # you can pass gs:// path here :) 
    for encoded_image, label in ds_train_encoded_iter:
        example = make_example(encoded_image, label)
        f.write(example)

ds_valid_encoded_iter = (
    ds_valid_encoded
    .as_numpy_iterator()
)
with tf.io.TFRecordWriter(path='/home/et/medai/images/tfrecords/test.tfrecord') as f:
    for encoded_image, label in ds_valid_encoded_iter:
        example = make_example(encoded_image, label)
        f.write(example)

How I tried to visualize the images in the tfrecords

Step 1: decoding functions

def _parse_image_function(example):
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([40], tf.float32),
    }

    features = tf.io.parse_single_example(example, image_feature_description)
    image = tf.image.decode_jpeg(features['image'], channels=3)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    # image = features['image']
    label = features['label']

    return image, label


def read_dataset(filename, batch_size):
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.map(_parse_image_function, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(500)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # dataset = dataset.repeat()
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

Step 2: decode and display

x = read_dataset('/home/et/medai/images/tfrecords/tests_train.tfrecord', 32)

plt.figure(figsize=(10, 10))
batch_size = 32
for images, labels in x.take(1):
    for i in range(batch_size):
        # display.display(display.Image(data=images[i].numpy()))

        ax = plt.subplot(6, 6, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.axis("off")

The result is something distorted: https://i.stack.imgur.com/tCAik.jpg

I am not quite sure where this distortion comes from. Original images look like this:

https://i.stack.imgur.com/Zi4HG.png

Any ideas?



Solution 1:[1]

I had a similar issue. I fixed it by normalizing the image during image processing (in your case, in process_image).

When pixel values are in the 0 to 255 range, the image tends to break during manipulations such as conversion to bytes or resizing, because these operations round the pixel values. Try normalizing your pixel data to float values between 0.0 and 1.0.

I used OpenCV to fix this, and I hope the code below helps you solve your problem in a similar way.

# This line distorted my images.
img = cv2.normalize(img, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)

# I changed to this line, and it worked.
img = cv2.normalize(img, None, alpha=0., beta=1., norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
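
The same idea can be sketched in TensorFlow for the process_image function from the question. image_dataset_from_directory yields float32 pixels in the 0 to 255 range, while tf.image.convert_image_dtype expects float input in the 0 to 1 range and multiplies it by roughly 255 when converting to uint8, so unscaled input overflows. The version below is only a minimal sketch under assumptions: NUM_CLASSES = 40 is inferred from the label length used in the parsing function, and the flat gray test image is a synthetic stand-in for one unbatched image.

```python
import tensorflow as tf

NUM_CLASSES = 40  # assumption: matches the [40] label length in the parser


def process_image(image, label):
    # Input is assumed to be float32 with values in [0, 255].
    # Rescale to [0, 1] first, because convert_image_dtype treats
    # float images as normalized before scaling them up to uint8.
    image = tf.image.convert_image_dtype(
        image / 255.0, dtype=tf.uint8, saturate=True
    )
    image = tf.io.encode_jpeg(image)
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


# Round-trip check on a synthetic image (hypothetical test data).
fake = tf.fill([16, 16, 3], 128.0)
encoded, one_hot = process_image(fake, tf.constant(3))
decoded = tf.io.decode_jpeg(encoded, channels=3)

# JPEG is lossy, so compare with a small tolerance instead of equality.
max_err = int(tf.reduce_max(tf.abs(tf.cast(decoded, tf.int32) - 128)))
print("max pixel error after round trip:", max_err)
```

Under the same assumption about the input range, casting directly with tf.cast(image, tf.uint8) instead of calling convert_image_dtype should also avoid the overflow, since a plain cast does not rescale.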

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Jinhwan Sul