Why does the VQ-VAE require two-stage training?

According to the paper, the VQ-VAE goes through two-stage training: first the encoder and the vector quantization codebook are trained, and then an auto-regressive model is trained to estimate the distribution over the resulting discrete codes.

    # Inside the vector-quantization layer's call(): x holds the encoder outputs,
    # quantized their nearest codebook vectors, and beta weights the commitment term.
    commitment_loss = self.beta * tf.reduce_mean(
        (tf.stop_gradient(quantized) - x) ** 2
    )
    codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
    self.add_loss(commitment_loss + codebook_loss)

Throughout training there is a reconstruction loss term together with these VQ loss terms.
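In other words, stage one only ever optimizes reconstruction plus the losses the quantizer registered via add_loss(). A rough sketch of that training step, assuming a vqvae Keras model built from an encoder, the quantizer above, and a decoder:

    optimizer = tf.keras.optimizers.Adam(1e-3)

    @tf.function
    def train_step(images):
        with tf.GradientTape() as tape:
            reconstructions = vqvae(images, training=True)
            reconstruction_loss = tf.reduce_mean((images - reconstructions) ** 2)
            # vqvae.losses contains the commitment and codebook terms from add_loss().
            total_loss = reconstruction_loss + tf.add_n(vqvae.losses)
        grads = tape.gradient(total_loss, vqvae.trainable_variables)
        optimizer.apply_gradients(zip(grads, vqvae.trainable_variables))
        return total_loss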

My question is: why do we not replace the VQ posterior with the autoregressive model, give it the VQ loss, and let it optimize for that same estimation in a single stage?

I ask this because the VQ process is a modified version of K-means clustering adapted to this setting, and clustering can be estimated with generative models. An auto-regressive model is a generative model, and in the second training stage it is used to estimate discrete values given the encoded x, with those discrete values serving as labels.
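For concreteness, the second stage in the paper does not use the VQ loss at all: the frozen encoder and quantizer turn each image into a grid of discrete code indices, and the autoregressive model (the pixelCNN) is fit to those indices with ordinary categorical cross-entropy. A rough sketch, where get_code_indices is an assumed helper exposing the quantizer's nearest-neighbour lookup and pixelcnn is any autoregressive model over integers in [0, num_embeddings):

    # get_code_indices and pixelcnn are assumed names, not from the paper.
    codes = get_code_indices(encoder(images))   # (batch, h, w) integer code indices
    logits = pixelcnn(codes, training=True)     # (batch, h, w, num_embeddings)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(codes, logits, from_logits=True)
    )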

A clustering generative model trained with the VQ loss as its signal should hypothetically reach the same goal.

I might be in over my head, as I am no deep learning or statistics expert.

If my question is not clear enough, please let me know.



Solution 1:[1]

I've been pondering this question as well, as I'm trying to do a generative art project using the VQ-VAE. The paper is unclear on why this two-stage pipeline is necessary, although I can hazard a guess: while the embeddings and the quantizer aren't yet settled, there is little point in trying to optimize an autoregressive estimator in the pixelCNN. But this doesn't explain why it isn't viable in the tail part of the training curve, when the weights start to settle. So this is little more than conjecture on my part; I still have to wrap my head around what the role of the pixelCNN is in the first place. Did you get any further on this question?

Solution 2:[2]

The author of the paper, Mr. van den Oord, kindly replied to this question, which I emailed to him. I'll try to phrase his answer as best as I can. The short answer is that my initial hunch was in the right direction:

  • As long as the weights of the VQ-VAE embeddings aren't settled, the pixelCNN would have to be retrained after every change in order to learn them.

He also elaborated on this a bit more:

  • Re-training the pixelCNN on unsettled embeddings would be less of an issue if it weren't for the fact that this optimization is computationally more expensive than the VQ-VAE training itself. So it is better to wait until you have a stable set of VQ-VAE weights (see the sketch after this list).
  • When doing hyperparameter tuning, it is better to keep the embeddings fixed in order to evaluate the pixelCNN results: pixelCNN evaluation results obtained on different VQ-VAE embeddings are harder to compare.
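In practice this decoupling is often implemented by freezing the stage-one weights and extracting the code indices once, so the pixelCNN trains on a fixed dataset. A rough sketch using the same assumed get_code_indices helper as above:

    # Freeze stage-one weights so the codebook, and hence the indices, stop moving.
    encoder.trainable = False
    quantizer.trainable = False

    # Extract the discrete code indices once and reuse them as the pixelCNN's dataset.
    # If the embeddings kept changing, these targets would go stale and the
    # pixelCNN would have to be retrained from scratch.
    code_dataset = image_dataset.map(lambda images: get_code_indices(encoder(images)))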

HTH

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: Rein
[2] Solution 2: Rein