Introduction
This post is adapted from a series of generative modelling blog posts I wrote for my personal website.
Generative AI has been a subject of much interest lately, as researchers and practitioners continue to make significant strides in developing methods for image synthesis and other creative tasks. One notable advancement has been the introduction of the Stable Diffusion model, which offers a controlled, step-by-step process for generating images. In this blog post, I’ll explore the use of VQ-VAE, a technique for obtaining compressed, discrete representations of images, and how we can leverage these representations in two ways: 1) by training a transformer to generate new images by next-token prediction, and 2) by training a transformer to perform discrete diffusion, similar to the process Stable Diffusion uses. My aim is to give you a clearer picture of the developments happening in generative AI.
Vector-quantized variational autoencoders
Introduced in the paper Neural Discrete Representation Learning, VQ-VAEs are an extension of ordinary variational autoencoders. I won’t cover them in great detail here; if you’d like to know more, I’d recommend reading this blog post. In short, a VQ-VAE is an autoencoder that takes the continuous latent representation output by the encoder and maps it to a discrete representation using a vector quantization layer. Typically the encoder is a series of convolutional layers that reduce the spatial dimensions of the image while increasing the number of channels. For example, a sequence of such layers can map an image of size (3,128,128) to a tensor of size (1024,16,16).
In this case, we’d choose a VQ layer with an embedding size of 1024. We initialise a set of random embedding vectors to make up a codebook; for this example, assume we choose 256 such vectors. Thus the embedding layer has a codebook of size (256,1024). When we apply the discrete encoding to the (1024,16,16) tensor, we map each of the 16×16 vectors to one of the 256 codebook vectors, specifically the codebook vector that is closest. Having done so, we have a discrete representation of the image of size (16,16,1), where the final dimension is an integer index identifying one of the codebook vectors. By adding a decoder network we can then try to reconstruct the original image from the discrete representation.
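To make the lookup concrete, here is a minimal NumPy sketch of the nearest-codebook-vector step, using the shapes from the example above. The function name `quantize` and the random stand-in data are my own assumptions for illustration, and I assume squared Euclidean distance as the notion of "closest"; this is not the code used for the experiments in this post.

```python
import numpy as np

def quantize(latents, codebook):
    """Map each spatial vector in `latents` to the index of its nearest codebook vector.

    latents:  (C, H, W) continuous encoder output
    codebook: (K, C) embedding vectors
    returns:  indices of shape (H, W) and quantized latents of shape (C, H, W)
    """
    c, h, w = latents.shape
    flat = latents.reshape(c, h * w).T                 # (H*W, C), one row per position
    # Squared Euclidean distance between every latent vector and every code.
    dists = (
        (flat ** 2).sum(axis=1, keepdims=True)         # (H*W, 1)
        - 2.0 * flat @ codebook.T                      # (H*W, K)
        + (codebook ** 2).sum(axis=1)                  # (K,)
    )
    indices = dists.argmin(axis=1)                     # (H*W,)
    quantized = codebook[indices].T.reshape(c, h, w)   # swap in the chosen codes
    return indices.reshape(h, w), quantized

rng = np.random.default_rng(0)
latents = rng.normal(size=(1024, 16, 16))   # random stand-in for an encoder output
codebook = rng.normal(size=(256, 1024))     # 256 codes of dimension 1024
indices, quantized = quantize(latents, codebook)
print(indices.shape, quantized.shape)       # (16, 16) (1024, 16, 16)
```

The integer grid `indices` is the discrete representation that the transformers later in the post are trained on, while `quantized` is what a decoder would receive.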
You train VQ-VAEs in a similar way to ordinary variational autoencoders, but with some extra details that I won’t cover here. I will mention that while training the model you can also train a discriminator that learns to differentiate between the real and reconstructed images. Instead of using the mean squared error for the reconstruction loss, you can then use a perceptual loss based on the discriminator. A perceptual loss takes the discriminator’s activations for the real and the reconstructed image and minimises the difference between them, so the reconstructed image is forced to have activations similar to those of the real image. As a result, the discriminator learns realistic features and the VQ-VAE learns to create images that activate those features.
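As a rough illustration of the perceptual loss, the sketch below compares two lists of activations, one list per image and one array per discriminator layer. The choice of a mean squared difference and of equal weighting across layers is my assumption; the post does not pin down the exact distance, and `perceptual_loss` is a hypothetical helper name.

```python
import numpy as np

def perceptual_loss(real_activations, recon_activations):
    """Average squared difference between matching discriminator activations.

    Both arguments are lists with one array per discriminator layer.
    """
    per_layer = [
        np.mean((real - recon) ** 2)
        for real, recon in zip(real_activations, recon_activations)
    ]
    return float(np.mean(per_layer))

# Stand-in activations for two layers (assumed shapes, not from a real network).
real = [np.ones((4, 8, 8)), np.ones((8, 4, 4))]
recon = [np.zeros((4, 8, 8)), np.ones((8, 4, 4))]
print(perceptual_loss(real, recon))   # 0.5
```

In a real training loop the activations would come from a forward pass of the discriminator, and the loss would be backpropagated through the reconstruction.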
This works pretty well, and at least much better than variational autoencoders on their own (see Fig 1).
Fig 1. Real and reconstructed image pairs. Real images are on
the left, and reconstructed images are on the right.
Next, I’ll discuss how to use the latent vectors from a VQ-VAE to train a transformer for generative modelling. In particular, I’ll cover two approaches. The first was introduced in the paper Taming Transformers for High-Resolution Image Synthesis. The second is somewhat exploratory: it roughly adapts diffusion models to discrete latent spaces using transformers and is inspired by the famous Stable Diffusion paper. (I later discovered that the approach taken here is a crude version of what is done in the paper Vector Quantized Diffusion Model for Text-to-Image Synthesis.)
Generative Modeling using Sequential Transformers
Transformers are a class of models that, when used generatively, take a sequence of discrete tokens and predict the next token in the sequence. They’ve proven to be very effective at a variety of tasks, including language modelling, image classification and image captioning. I won’t go into the details of how transformers work here, but if you’re interested in learning more I’d recommend the Illustrated Transformer blog post.
One idea you might have is to treat an image as a sequence of pixels and use their channel values (0-255) as the discrete token values. You can then train a transformer to do next-token prediction on the channel tokens themselves. This has been done and does work, but it doesn’t scale well to large images. The reason is the attention mechanism in the transformer, which compares every token in the sequence with every other token. For an image of size (128,128,1), we’d have to compare 128×128 = 16,384 tokens with each other. This is a lot of computation and is not feasible for large image sizes. However, using our VQ-VAEs we can reduce this problem.
By training a VQ-VAE we obtain a model that can encode images into a discrete latent space. As an example, you might have an image of shape (128,128,3), and using the VQ-VAE we can encode it into a latent sequence of shape (16×16, 1). Assuming the codebook size is 256, we’ve gone from 128×128×3 float32 values to 16×16 8-bit integer values. This means two things: firstly, the data now has a discrete representation, and secondly, this representation is much smaller than the original. Both properties are desirable for training transformers, because the attention mechanism now only has to compare 16×16 = 256 tokens with each other. This is much more manageable, and once we’ve generated a new sequence of tokens we can decode it back into the larger image space.
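The saving is easy to check with a few lines of arithmetic. The sketch below assumes full self-attention, where every token is compared with every token (itself included), one token per channel value for the raw image, 4 bytes per float32 value and 1 byte per latent token. The helper `attention_pairs` is just an illustrative name.

```python
def attention_pairs(num_tokens):
    """Number of token-to-token comparisons in full self-attention."""
    return num_tokens * num_tokens

pixel_tokens = 128 * 128 * 3      # one token per channel value
latent_tokens = 16 * 16           # one token per codebook index

print(pixel_tokens, attention_pairs(pixel_tokens))      # 49152 2415919104
print(latent_tokens, attention_pairs(latent_tokens))    # 256 65536
print(attention_pairs(pixel_tokens) // attention_pairs(latent_tokens))  # 36864

pixel_bytes = pixel_tokens * 4    # float32 values
latent_bytes = latent_tokens * 1  # 8-bit integers
print(pixel_bytes // latent_bytes)                       # 768
```

So under these assumptions the latent sequence needs roughly 37,000 times fewer attention comparisons per layer and 768 times less storage than the raw image.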
Training transformers is remarkably simple. The main complexity is in preprocessing the training data, which requires converting each image into its discrete latent representation. The transformer takes the sequence of tokens (s_1, s_2, …, s_n), of shape (n, 1), and outputs a sequence of probability vectors (p_1, p_2, …, p_n), of shape (n, 256). We then use the cross-entropy loss to compare the output probability vectors with the ground-truth tokens. Importantly, the token s_i is used as the target for the probability vector p_{i-1}.
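The shift between outputs and targets is the part that is easy to get wrong, so here is a small NumPy sketch of the loss. I assume the model returns unnormalised scores (logits) that are turned into probabilities with a softmax; `next_token_loss` is a hypothetical helper, not the training code behind the results in this post.

```python
import numpy as np

def next_token_loss(logits, tokens):
    """Cross-entropy where the output at position i-1 is scored against token i.

    logits: (n, vocab) unnormalised scores, one row per input position
    tokens: (n,) integer token ids that were fed to the model
    """
    # Log-softmax over the vocabulary, computed in a numerically stable way.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    predictions = log_probs[:-1]   # outputs for all positions except the last
    targets = tokens[1:]           # every token except the first
    picked = predictions[np.arange(len(targets)), targets]
    return float(-picked.mean())

rng = np.random.default_rng(0)
tokens = rng.integers(0, 256, size=256)       # stand-in for one encoded image
uniform_logits = np.zeros((256, 256))         # a model that knows nothing
print(next_token_loss(uniform_logits, tokens))   # log(256), about 5.545
```

A model that has learned nothing scores log(256), so that value is a useful baseline to compare the training loss against.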
Once we’ve trained the transformer, we can generate new images by feeding it an initial random token and repeatedly predicting the next token in the sequence until we’ve generated all 16×16 = 256 tokens. We can then decode the sequence of tokens back into the original image space. I spent a while trying to figure out why my transformer seemed to train well but not generate good images, only to discover that I had an off-by-one error in the code that sampled the next token in the sequence 🤦! Once I fixed this, the images started to look much better.
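The sampling loop can be sketched as follows. The `uniform_model` function is a placeholder I made up so the example runs on its own; in practice it would be replaced by the trained transformer, and I assume that model returns a probability vector for the next token given all tokens so far.

```python
import numpy as np

def sample_tokens(next_token_probs, prefix, total_tokens, rng):
    """Extend `prefix` one token at a time until it has `total_tokens` entries.

    next_token_probs: hypothetical callable mapping a 1-D token array to a
                      probability vector over the vocabulary for the next token
    """
    tokens = list(prefix)
    while len(tokens) < total_tokens:
        probs = next_token_probs(np.array(tokens))
        tokens.append(int(rng.choice(len(probs), p=probs)))
    return np.array(tokens)

def uniform_model(tokens, vocab_size=256):
    # Stand-in for a trained transformer: every token is equally likely.
    return np.full(vocab_size, 1.0 / vocab_size)

rng = np.random.default_rng(0)
start = rng.integers(0, 256, size=1)             # initial random token
sequence = sample_tokens(uniform_model, start, 16 * 16, rng)
latent_grid = sequence.reshape(16, 16)           # ready for the VQ-VAE decoder
print(latent_grid.shape)                         # (16, 16)
```

Passing a longer `prefix`, for example the first 8×16 tokens of an encoded image, turns the same loop into image completion.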
Results
In each of Fig 2, Fig 3 and Fig 4, the first set of 3 by 3 images on the left is generated by the transformer in the latent space. The second set is the VQ-VAE reconstructions of original images from the CelebA dataset. In Fig 2, I map an image to its discrete latent representation, keep the first 8×16 tokens of the encoding and throw away the rest. I then use the trained latent-space transformer to regenerate the missing tokens. Hence the resulting images match on the top half and vary on the bottom.
Fig 2. The first 3 by 3 set of 9 images on the left are the transformer
generations, and the second set of 3 by 3 images on the right is the
VQ-VAE reconstructions. The transformer uses the first 8×16 tokens
of the VQ-VAE encodings to generate the missing tokens.
Fig 3. The first 3 by 3 set of 9 images on the left are the transformer
generations, and the second set of 3 by 3 images on the right is the
VQ-VAE reconstructions. The transformer uses the first 4×16 tokens
of the VQ-VAE encodings to generate the missing tokens.
Fig 4. The first 3 by 3 set of 9 images on the left are the transformer
generations, and the second set of 3 by 3 images on the right is the
VQ-VAE reconstructions. The transformer uses the first token of the
VQ-VAE encodings to generate the missing tokens.
The idea of training a generative model in the latent space of an autoencoder has also been applied to diffusion models, most notably by StabilityAI when they trained the Stable Diffusion model. In the Stable Diffusion paper, they first train a VQ-VAE and then use the continuous outputs of the encoder, prior to the vector quantization layer, to train the diffusion model.
I wanted to try the above, but using the discrete latent space produced after the vector quantization layer instead. A similar idea has been pursued in Vector Quantized Diffusion Model for Text-to-Image Synthesis. To do so, I use the same transformer architecture as in the previous case. To train the discrete diffusion model, we first need to generate the training data. We take the encoded discrete representation of an image, x, and randomly perturb a random percentage of the tokens in the sequence to get x_a. Next, we further perturb x_a to obtain x_b, this time by randomly changing a fixed percentage of the tokens. In my experiments, I perturbed 10% of the tokens between x_a and x_b. We then train the diffusion model, g, to minimise the categorical cross-entropy between x_a and g(x_b).
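Here is a minimal sketch of how such a training pair could be built. I assume that "perturbing" a token means replacing it with a token id drawn uniformly from the codebook, and that the random percentage is drawn uniformly between 0% and 100%; the helper names `perturb` and `make_training_pair` are hypothetical.

```python
import numpy as np

def perturb(tokens, fraction, vocab_size, rng):
    """Replace a randomly chosen `fraction` of `tokens` with random token ids."""
    out = tokens.copy()
    count = int(round(fraction * len(tokens)))
    if count == 0:
        return out
    positions = rng.choice(len(tokens), size=count, replace=False)
    out[positions] = rng.integers(0, vocab_size, size=count)
    return out

def make_training_pair(x, rng, vocab_size=256, step_fraction=0.10):
    """Return (x_b, x_a): the model input and its slightly cleaner target."""
    x_a = perturb(x, rng.uniform(0.0, 1.0), vocab_size, rng)   # random noise level
    x_b = perturb(x_a, step_fraction, vocab_size, rng)         # fixed extra 10%
    return x_b, x_a

rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=16 * 16)       # stand-in for one encoded image
x_b, x_a = make_training_pair(x, rng)
print(int((x_a != x_b).sum()))               # at most 26 tokens differ
```

Note that a replaced token can land on its old value by chance, so the number of tokens that actually differ between x_a and x_b is an upper bound rather than an exact count.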
Results
To generate samples using the diffusion transformer we first generate a set of completely random tokens. When we pass these tokens through the VQ-VAE decoder we get the strange noisy images you can see in Fig 5, 6 and 7. We then pass these to the diffusion transformer which outputs a set of probabilities corresponding to how it thinks we should change the random tokens to make the tokens more like the data it has been trained on.
Fig 5, Fig 6 and Fig 7 show the sequence of images generated in the reverse diffusion process. In Fig 5, we use a lower temperature when sampling the tokens in the next sequence. This means the model only changes a token when it is sure that it makes sense to do so, which results in longer generation times. In Fig 6, we use a higher temperature and thus generation is faster.
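The reverse process with a sampling temperature can be sketched like this. The exact sampling rule I used is not spelled out in this post, so treat the following as one plausible version: the model's scores are divided by the temperature, turned into probabilities with a softmax, and every position is resampled at each step. `toy_model` and `TARGET` are made-up stand-ins for the trained diffusion transformer so that the example runs on its own.

```python
import numpy as np

def denoise_step(logits, temperature, rng):
    """Sample a new token for every position from temperature-scaled logits."""
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    probs = np.exp(scaled)
    probs /= probs.sum(axis=1, keepdims=True)
    # Inverse-CDF sampling, one independent draw per position.
    cumulative = probs.cumsum(axis=1)
    draws = rng.random((len(probs), 1))
    return (cumulative < draws).sum(axis=1).clip(max=probs.shape[1] - 1)

def reverse_diffusion(model_logits, num_tokens, vocab_size, steps, temperature, rng):
    tokens = rng.integers(0, vocab_size, size=num_tokens)   # start from pure noise
    for _ in range(steps):
        tokens = denoise_step(model_logits(tokens), temperature, rng)
    return tokens

# Hypothetical stand-in for the trained model: it mildly prefers keeping the
# current token and more strongly prefers a fixed "clean" sequence TARGET.
TARGET = np.arange(256)[::-1].copy()

def toy_model(tokens, vocab_size=256):
    rows = np.arange(len(tokens))
    logits = np.zeros((len(tokens), vocab_size))
    logits[rows, tokens] += 2.0
    logits[rows, TARGET] += 3.0
    return logits

rng = np.random.default_rng(0)
result = reverse_diffusion(toy_model, 256, 256, steps=50, temperature=0.5, rng=rng)
print((result == TARGET).mean())   # fraction of positions that reached TARGET
```

Lowering the temperature sharpens each position's distribution, so a token is only replaced when the model puts clearly more weight on an alternative, which matches the slower but more cautious behaviour described above.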
The final figure, Fig 7, shows an animated gif of the process occurring. You can see how the model pieces together the image like a jigsaw puzzle.
Fig 6. Reversed Discrete diffusion. The first image is a set of random
tokens, and each image after is generated from the gradually denoised
random tokens by iteratively passing through the diffusion transformer.
A higher sampling temperature means faster diffusion.
Fig 7. Gif animation of the reversed discrete diffusion process.
Conclusion
In comparison, the next-token-prediction transformer results are definitely better, although I think some improvements to the training of the discrete diffusion transformer are possible. What surprised me, though, is how simple the diffusion process was to set up. I think it was maybe a two- or three-line change in the training script.
Hopefully, this post has given you a high-level insight into how both these kinds of models work.