Variational autoencoder doesn't reconstruct accurate images
I took the variational autoencoder example built in Keras (keras.io/examples/generative/vae/) and wanted to try it on my own data. I always obtain the same output image, which is nonsense compared to the data, and the reconstruction error doesn't change even when I completely eliminate the KL term and test on a training image. Mean of data = 0.0542, std = 0.00971.
# Dimensionality of the latent space shared by the encoder and the decoder.
latent_dim = 2

# --- Encoder: (32, 32, 1) image -> (z_mean, z_log_var, z) ---
encoder_inputs = keras.Input(shape=(32, 32, 1))
x = layers.Conv2D(16, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Flatten()(x)
x = layers.Dense(32, activation="relu")(x)
x = layers.Dense(16, activation="relu")(x)
# Two parallel heads on the same features: mean and log-variance of the latent code.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# NOTE(review): `Sampling` is not defined in this snippet; presumably it is the
# sampling layer from the Keras VAE example -- confirm it is defined before this point.
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

# --- Decoder: latent vector -> single-channel image ---
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(16 * 16 * 16, activation="relu")(latent_inputs)
x = layers.Reshape((16, 16, 16))(x)
# The stride-2 transposed convolution upsamples the 16x16 feature map.
x = layers.Conv2DTranspose(16, 3, activation="relu", strides=2, padding="same")(x)
# Fix: the original line was missing the closing parenthesis of the layer
# constructor (`padding="same"(x)`), which is a syntax error.
x = layers.Conv2DTranspose(1, 3, activation="linear", padding="same")(x)
decoder = keras.Model(latent_inputs, x, name="decoder")
class VAE(keras.Model):
    """Variational autoencoder that chains an encoder and a decoder.

    Training is driven by the custom ``train_step`` below. The KL divergence
    is computed and tracked, but only the reconstruction loss is optimised
    (the KL term has been deliberately removed from the total loss).
    """

    def __init__(self, encoder, decoder, **kwargs):
        """Store the sub-models and create the running-mean loss trackers.

        Args:
            encoder: Model called on a batch; must return the triple
                ``(z_mean, z_log_var, z)``.
            decoder: Model called on ``z`` to produce the reconstruction.
            **kwargs: Forwarded unchanged to the ``keras.Model`` constructor.
        """
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        # Weight intended for the KL term; train_step does not read it.
        self.klweight = 0
        # NOTE(review): inherited from keras.Model; confirm this call is still
        # needed and available in the installed Keras version.
        self.reset_states()

    @property
    def metrics(self):
        """Return the three loss trackers updated by ``train_step``."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step on a batch and report running losses.

        Args:
            data: Batch passed directly to the encoder, and used as the
                reconstruction target.

        Returns:
            Dict with the running means of ``loss``, ``reconstruction_loss``
            and ``kl_loss``.
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            # mean_absolute_error averages over the last axis; the remaining
            # axes 1 and 2 are summed per sample, then averaged over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.mean_absolute_error(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # KL divergence of N(z_mean, exp(z_log_var)) from N(0, 1), summed
            # over latent dimensions and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            # The KL term is intentionally left out of the optimised loss.
            total_loss = reconstruction_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }
# Assemble the VAE from the encoder/decoder pair, then train it with Adam
# (default settings) for 10 epochs in mini-batches of 32.
vae = VAE(encoder=encoder, decoder=decoder)
vae.compile(optimizer=keras.optimizers.Adam())
vae.fit(train, batch_size=32, epochs=10)
Epoch 1/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0708 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 2/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0703 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 3/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 4/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08
Epoch 5/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0704 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08
Epoch 6/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0705 - reconstruction_loss: 0.0704 - kl_loss: 1.3370e-08
Epoch 7/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 8/10
882/882 [==============================] - 4s 4ms/step - loss: 0.0704 - reconstruction_loss: 0.0704 - kl_loss: 1.3436e-08
Epoch 9/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3468e-08
Epoch 10/10
882/882 [==============================] - 3s 4ms/step - loss: 0.0702 - reconstruction_loss: 0.0704 - kl_loss: 1.3403e-08
Do you have any idea what is going wrong?
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
