Variational autoencoder doesn't reconstruct accurate images

I took the variational autoencoder example from the Keras documentation (keras.io/examples/generative/vae/) and tried it on my own data. I always get the same output image, which bears no resemblance to the data, and the reconstruction error does not change, even when I remove the KL term completely and test on a training image. The data have mean 0.0542 and standard deviation 0.00971.

latent_dim = 2  # dimensionality of the latent code z

# ---- Encoder: (32, 32, 1) image -> [z_mean, z_log_var, z] ----
encoder_inputs = keras.Input(shape=(32, 32, 1))
# Stride-2 "same" convolution halves the spatial size: 32x32 -> 16x16.
x = layers.Conv2D(
    filters=16, kernel_size=3, strides=2, padding="same", activation="relu"
)(encoder_inputs)
x = layers.Flatten()(x)
x = layers.Dense(units=32, activation="relu")(x)
x = layers.Dense(units=16, activation="relu")(x)
# Two linear heads: mean and log-variance of the latent Gaussian (the same
# quantities the KL term in train_step is computed from).
z_mean = layers.Dense(units=latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(units=latent_dim, name="z_log_var")(x)
# NOTE(review): `Sampling` is not defined in this snippet; presumably it is
# the reparameterization layer from the keras.io VAE example — confirm.
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(
    inputs=encoder_inputs, outputs=[z_mean, z_log_var, z], name="encoder"
)
encoder.summary()
    
    
# ---- Decoder: latent vector z -> (32, 32, 1) image ----
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(16 * 16 * 16, activation="relu")(latent_inputs)
x = layers.Reshape((16, 16, 16))(x)
# Stride-2 "same" transposed convolution upsamples 16x16 -> 32x32, so the
# final output shape (32, 32, 1) matches the encoder input.
x = layers.Conv2DTranspose(16, 3, activation="relu", strides=2, padding="same")(x)
# Single-channel output with a linear activation: output values are unbounded.
x = layers.Conv2DTranspose(1, 3, activation="linear", padding="same")(x)

decoder = keras.Model(latent_inputs, x, name="decoder")

    def __init__(self, encoder, decoder, **kwargs):
        """Store the encoder/decoder pair and create the loss trackers.

        Args:
            encoder: Model mapping an image batch to ``[z_mean, z_log_var, z]``.
            decoder: Model mapping latent vectors ``z`` back to images.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        # NOTE(review): the `class VAE(keras.Model):` line that should enclose
        # these methods is missing from this paste; `VAE(encoder, decoder)`
        # further down requires it.
        super().__init__(**kwargs)
        self.encoder, self.decoder = encoder, decoder
        # One running-mean metric per reported loss component.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        # Presumably the weight of the KL term — verify that train_step reads it.
        self.klweight = 0
        self.reset_states()
    @property
    def metrics(self):
        """Return the loss trackers this model reports, in a fixed order.

        The Keras custom-``fit`` guide lists ``Metric`` objects here so that
        Keras resets them automatically between epochs.
        """
        trackers = (
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        )
        return list(trackers)

    def train_step(self, data):
        """Run one optimization step on a batch and update the loss trackers.

        The optimized loss is ``reconstruction_loss + self.klweight * kl_loss``:

        * ``reconstruction_loss``: per-pixel absolute error, summed over each
          image and averaged over the batch.
        * ``kl_loss``: closed-form KL divergence of N(z_mean, exp(z_log_var))
          from N(0, I), summed over latent dimensions and averaged over the batch.

        With ``klweight = 0`` (the value set in ``__init__``) this is the
        reconstruction-only objective.

        NOTE(review): in the posted log kl_loss stays ~1e-8. Each KL summand
        ``z_mean**2 + exp(v) - 1 - v`` is >= 0, so this means z_mean ~ 0 and
        z_log_var ~ 0 for every input. z is then essentially unit-variance
        noise that carries almost no information about the image, and the
        decoder can only learn an average output. Standardizing the inputs
        (the data std is ~0.01) is a likely first fix — confirm.

        Args:
            data: A batch of images, or a tuple whose first element is the image
                batch (Keras may pass ``(x, y[, sample_weight])`` from ``fit``).

        Returns:
            dict mapping metric names to their running means for the epoch.
        """
        # An autoencoder only needs the inputs; drop any targets or weights.
        if isinstance(data, tuple):
            data = data[0]
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # mean_absolute_error averages over the last (channel) axis; summing
            # axes 1 and 2 then gives one error per image.
            # NOTE(review): assumes data is (batch, height, width, channels).
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mean_absolute_error(data, reconstruction),
                    axis=(1, 2),
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            # Weighting the KL term (instead of dropping it by hand) lets it be
            # enabled or annealed through self.klweight without editing this
            # method. The default of 0 keeps the reconstruction-only objective.
            total_loss = reconstruction_loss + self.klweight * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Assemble the VAE from the encoder/decoder built earlier and train it with
# Adam at its default learning rate.
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(train, batch_size=32, epochs=10)
Epoch 1/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0708 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 2/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0703 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 3/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 4/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08
Epoch 5/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0704 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08
Epoch 6/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3370e-08
Epoch 7/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 8/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0704 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 9/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3468e-08
Epoch 10/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08

Does anyone have an idea what is going wrong?



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source