TF pipeline to dynamically extract patches and flatten dataset

I'm trying to train an autoencoder on image patches. My training data consists of single-channel images loaded into a numpy array with shape [10000, 256, 512, 1]. I know how to extract patches from the images, but the resulting pipeline is unintuitive: batching selects whole images, so the number of samples per batch depends on how many patches are extracted from each image. If 32 patches of size 64x64 are extracted per image, I'd like the dataset to behave as if it had shape [320000, 64, 64, 1], so that shuffling and batching draw from several images at a time, while the patches are still extracted on the fly and the full patch set never has to be held in memory.
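For reference, 32 patches per image is what tiling each 256x512 image with non-overlapping 64x64 patches would give; a quick back-of-the-envelope check:

# Patch-count arithmetic (illustrative): non-overlapping 64x64 tiles
patches_per_image = (256 // 64) * (512 // 64)  # 4 * 8 = 32
total_patches = 10000 * patches_per_image      # 320000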

The closest existing question I've found is Load tensorflow images and create patches, but, as mentioned above, it doesn't give me what I want.

import tensorflow as tf

PATCH_SIZE = 64

def extract_patches(imgs, patch_size=PATCH_SIZE, stride=PATCH_SIZE//2):
    # extract patches and reshape them into patch images
    n_channels = imgs.shape[-1]
    if len(imgs.shape) < 4:
        imgs = tf.expand_dims(imgs, axis=0)  
    return tf.reshape(tf.image.extract_patches(imgs,
                                               sizes=[1, patch_size, patch_size, n_channels],
                                               strides=[1, stride, stride, n_channels],
                                               rates=[1, 1, 1, 1],
                                               padding='VALID'),
                      (-1, patch_size, patch_size, n_channels))
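
As a sanity check (purely illustrative), running the function on a single dummy image shows how many patches each image produces with the default stride of 32:

print(extract_patches(tf.zeros((256, 512, 1))).shape)
# (105, 64, 64, 1): ((256-64)//32 + 1) * ((512-64)//32 + 1) = 7 * 15 = 105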

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )

creates a dataset that returns batches with shape (batch_size, 105, 64, 64, 1), whereas I want a rank-4 tensor with shape (batch_size, 64, 64, 1) and shuffling to operate on individual patches rather than on the per-image collections of 105 patches. If I instead put .map at the end of the pipeline:

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            )

This does flatten the patches and return a rank-4 tensor, but each batch now has shape (840, 64, 64, 1), i.e. 8 images x 105 patches per image, and shuffling still operates on whole images rather than on patches.
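One idea (a minimal sketch, assuming tf.data.Dataset.unbatch behaves the way I expect) would be to split each image's stack of patches back into individual elements right after the map, so that shuffle and batch see single patches:

batch_size = 8
dataset = (tf.data.Dataset.from_tensor_slices(tf.cast(imgs, tf.float32))
            .map(extract_patches, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)
            .unbatch()  # (105, 64, 64, 1) per image -> stream of (64, 64, 1) patches
            .shuffle(10*batch_size, reshuffle_each_iteration=True)
            .batch(batch_size)
            )

Note that the shuffle buffer would then hold patches rather than images, so 10*batch_size (80) elements is less than one image's worth of patches; it would presumably need to be much larger to actually mix patches from different images.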


