Apply a transformation model (data augmentation) to images in TensorFlow
I am a newbie to sequential models in TensorFlow with Python. I have a sequential transformation model like the one below. It randomly applies a few operations, with random parameters, to a given input image.
```python
import tensorflow as tf
from tensorflow.keras import layers

data_transformation = tf.keras.Sequential(
    [
        layers.Lambda(lambda x: my_random_brightness(x, 1, 20)),
        # note: stddev is drawn once, when the model is constructed, not on every call
        layers.GaussianNoise(stddev=tf.random.uniform(shape=(), minval=0, maxval=1)),
        layers.experimental.preprocessing.RandomRotation(
            factor=0.01,
            fill_mode="reflect",
            interpolation="bilinear",
            seed=None,
            name=None,
            fill_value=0.0,
        ),
        layers.experimental.preprocessing.RandomZoom(
            height_factor=(0.1, 0.2),
            width_factor=(0.1, 0.2),
            fill_mode="reflect",
            interpolation="bilinear",
            seed=None,
            name=None,
            fill_value=0.0,
        ),
    ]
)
```
The Lambda layer in this model calls the following function:
```python
def my_random_brightness(
    image_to_be_transformed, brightness_factor_min, brightness_factor_max
):
    # build the brightness factor (one random scalar for the whole batch)
    selected_brightness_factor = tf.random.uniform(
        (), minval=brightness_factor_min, maxval=brightness_factor_max
    )
    # shift each channel; the 4-index slices expect a batch (batch, height, width, channels)
    c0 = image_to_be_transformed[:, :, :, 0] + selected_brightness_factor
    c1 = image_to_be_transformed[:, :, :, 1] + selected_brightness_factor
    c2 = image_to_be_transformed[:, :, :, 2] + selected_brightness_factor
    # put the shifted channels back one at a time
    image_to_be_transformed = tf.concat(
        [c0[..., tf.newaxis], image_to_be_transformed[:, :, :, 1:]], axis=-1
    )
    image_to_be_transformed = tf.concat(
        [
            image_to_be_transformed[:, :, :, 0][..., tf.newaxis],
            c1[..., tf.newaxis],
            image_to_be_transformed[:, :, :, 2][..., tf.newaxis],
        ],
        axis=-1,
    )
    image_to_be_transformed = tf.concat(
        [image_to_be_transformed[:, :, :, :2], c2[..., tf.newaxis]], axis=-1
    )
    return image_to_be_transformed
```
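The three concat steps above end up adding the same scalar to each of the three channels, so for an RGB batch the function should be equivalent to `image + factor`. A minimal check, using NumPy as a stand-in for the TensorFlow ops (the slicing and concatenation semantics are the same) and a fixed factor instead of a random one; `concat_brightness` is a hypothetical name:

```python
import numpy as np


def concat_brightness(images, factor):
    """NumPy mirror of my_random_brightness, with a fixed factor instead of a random one."""
    c0 = images[:, :, :, 0] + factor
    c1 = images[:, :, :, 1] + factor
    c2 = images[:, :, :, 2] + factor
    # replace channel 0, then channel 1, then channel 2, exactly as the TF version does
    out = np.concatenate([c0[..., np.newaxis], images[:, :, :, 1:]], axis=-1)
    out = np.concatenate(
        [
            out[:, :, :, 0][..., np.newaxis],
            c1[..., np.newaxis],
            out[:, :, :, 2][..., np.newaxis],
        ],
        axis=-1,
    )
    out = np.concatenate([out[:, :, :, :2], c2[..., np.newaxis]], axis=-1)
    return out


rng = np.random.default_rng(0)
batch = rng.uniform(0.0, 255.0, size=(2, 4, 4, 3)).astype(np.float32)  # (batch, H, W, C)
factor = 7.5

# a plain broadcast addition gives the same result as the three concat steps
same = np.allclose(concat_brightness(batch, factor), batch + factor)
print(same)  # True
```

A broadcast addition like `image + factor` is therefore a simpler equivalent, and it also works on a single 3-D image, whereas the slicing version needs the batch axis.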
Now suppose I want to apply this model to a batch containing just one image, and I want to see and save the randomly transformed result. How can I do that? Is there a predict()- or flow()-like function that outputs such a result?
EDIT: I tried result = data_transformation(image) and got the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Index out of range using input dim 3; input has only 3 dims [Op:StridedSlice] name: sequential/lambda/strided_slice/
Solution 1:[1]
Setting aside whether the brightness processing layer above is correct, it is written to take a batch of images, not a single image. That is why you get this error: a slice such as [:, :, :, 0] needs a 4-D tensor (batch, height, width, channels), but a single image has only 3 dimensions (height, width, channels). Add a batch axis when passing a single image and it should work:
```python
result = data_transformation(image[None, ...])
```
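A fuller sketch of the whole round trip (add the batch axis, run the augmentation, drop the axis again, save to disk), assuming TF 2.6 or later, where these layers are available as `tf.keras.layers.RandomRotation` and so on, a dummy float image in [0, 255], and the hypothetical output path `augmented.png`. Random preprocessing layers such as `RandomRotation`, `RandomZoom` and `GaussianNoise` are documented as active only in training mode, so `training=True` is passed explicitly; in inference mode (which is also what `predict()` uses) they may return the input unchanged.

```python
import tensorflow as tf

# Hypothetical stand-in for the question's pipeline, built from Keras layers only.
augment = tf.keras.Sequential(
    [
        tf.keras.layers.GaussianNoise(stddev=5.0),
        tf.keras.layers.RandomRotation(0.01, fill_mode="reflect"),
        tf.keras.layers.RandomZoom((0.1, 0.2), fill_mode="reflect"),
    ]
)

# A single dummy RGB image: shape (H, W, C), no batch axis.
image = tf.random.uniform((64, 64, 3), minval=0.0, maxval=255.0)

batch = image[None, ...]                # add the batch axis -> (1, 64, 64, 3)
result = augment(batch, training=True)  # training=True so the random layers are active
augmented = result[0]                   # drop the batch axis again -> (64, 64, 3)

# Save as PNG: clip to the valid pixel range, cast to uint8, encode, write to disk.
png = tf.io.encode_png(tf.cast(tf.clip_by_value(augmented, 0.0, 255.0), tf.uint8))
tf.io.write_file("augmented.png", png)
```

Calling the model repeatedly on the same batch should give a different result each time, since every random layer samples new parameters per call.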
Also, for custom layers, prefer subclassing layers.Layer over wrapping a function in a Lambda layer:
```python
import tensorflow as tf
from tensorflow.keras import layers


class MyCustomBrightNess(layers.Layer):
    def __init__(self, brightness_factor_min, brightness_factor_max, **kwargs):
        super().__init__(**kwargs)
        self.brightness_factor_min = brightness_factor_min
        self.brightness_factor_max = brightness_factor_max

    def call(self, inputs):
        # build the brightness factor
        selected_brightness_factor = tf.random.uniform(
            (), minval=self.brightness_factor_min, maxval=self.brightness_factor_max
        )
        c0 = inputs[:, :, :, 0] + selected_brightness_factor
        c1 = inputs[:, :, :, 1] + selected_brightness_factor
        c2 = inputs[:, :, :, 2] + selected_brightness_factor
        inputs = tf.concat(
            [c0[..., tf.newaxis], inputs[:, :, :, 1:]], axis=-1
        )
        inputs = tf.concat(
            [
                inputs[:, :, :, 0][..., tf.newaxis],
                c1[..., tf.newaxis],
                inputs[:, :, :, 2][..., tf.newaxis],
            ],
            axis=-1,
        )
        inputs = tf.concat(
            [inputs[:, :, :, :2], c2[..., tf.newaxis]], axis=-1
        )
        return inputs

    def get_config(self):
        # store the constructor arguments so the layer can be rebuilt from its config
        config = {
            "brightness_factor_min": self.brightness_factor_min,
            "brightness_factor_max": self.brightness_factor_max,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
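Because the channel-by-channel concatenation only adds one scalar to every channel, the subclassed layer can be written more compactly. A minimal sketch, assuming TF 2.x; `RandomBrightnessShift`, `min_delta` and `max_delta` are hypothetical names, not a Keras API. It also shows what `get_config` is for: `from_config` rebuilds an equivalent layer from it.

```python
import tensorflow as tf


class RandomBrightnessShift(tf.keras.layers.Layer):
    """Hypothetical layer: adds one random scalar in [min_delta, max_delta) to every value."""

    def __init__(self, min_delta, max_delta, **kwargs):
        super().__init__(**kwargs)
        self.min_delta = min_delta
        self.max_delta = max_delta

    def call(self, inputs):
        # one scalar per call, broadcast over batch, height, width and channels
        delta = tf.random.uniform(
            (), minval=self.min_delta, maxval=self.max_delta, dtype=inputs.dtype
        )
        return inputs + delta

    def get_config(self):
        # include the constructor arguments so from_config can re-create the layer
        config = super().get_config()
        config.update({"min_delta": self.min_delta, "max_delta": self.max_delta})
        return config


layer = RandomBrightnessShift(1.0, 20.0)
images = tf.zeros((1, 8, 8, 3))
shifted = layer(images)  # every value equals the same random delta

# rebuild an equivalent layer from its config (e.g. when reloading a saved model)
restored = RandomBrightnessShift.from_config(layer.get_config())
```

Unlike the slicing version, `inputs + delta` broadcasts, so it also accepts a single 3-D image. Like the original, it shifts the input on every call, in training and inference alike.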
I haven't rigorously checked the correctness of this implementation. I would suggest using the official tf.image.random_brightness or tf.image.adjust_brightness instead. Or, if you're using TensorFlow 2.9, take a look at the new KerasCV, where you can find RandomBrightness layers.
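For comparison, a minimal sketch of the built-in `tf.image` ops, assuming float images scaled to [0, 1] (for float inputs these ops add `delta` directly, without clipping). `adjust_brightness` adds a fixed `delta`, while `random_brightness` draws one `delta` from `[-max_delta, max_delta)` for the whole tensor, so every image in a batch gets the same shift.

```python
import tensorflow as tf

images = tf.zeros((1, 8, 8, 3))  # dummy float batch, values assumed in [0, 1]

# Fixed shift: adds exactly 0.1 to every component.
fixed = tf.image.adjust_brightness(images, delta=0.1)

# Random shift: a single delta drawn from [-0.2, 0.2) for the whole tensor.
randomized = tf.image.random_brightness(images, max_delta=0.2)

# Wrapped in a Lambda layer so it can sit inside a Sequential augmentation model.
brightness_layer = tf.keras.layers.Lambda(
    lambda x: tf.image.random_brightness(x, max_delta=0.2)
)
out = brightness_layer(images)
```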
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | M.Innat |
