DALL-E: Zero-Shot Text-to-Image Generation by Aditya Ramesh et al. explained in 5 minutes.
⭐️Paper difficulty: 🌕🌕🌕🌕🌑
🎯 At a glance:
Wouldn’t it be amazing if you could simply type a text prompt describing the image in as much or as little detail as you want and a bunch of images fitting the description were generated on the fly? Well, thanks to the good folks at OpenAI it is possible! Introducing their DALL-E model that uses a discrete visual codebook obtained by training a discrete VAE, and a transformer to model the joint probability of text prompts and their corresponding images. And if that was not cool enough, they also make it possible to use an input image alongside a special text prompt as an additional condition to perform zero-shot image-to-image translation.
(Highly recommended reading to understand the core ideas in this paper):
🔍 Main Ideas:
1) Two stage approach:
Training a transformer directly on raw pixels is prohibitively expensive in terms of memory, so, like many other transformer-based image models, DALL-E first learns a compact discrete representation of images. This cuts the number of tokens the transformer has to process (a 256x256 RGB image, i.e. 196,608 pixel values, becomes a 32x32 grid of 1024 tokens) and lets the model spend its capacity on long-range dependencies between distant parts of the generated image rather than on pixel-level detail.
In the second stage, an autoregressive transformer processes a concatenation of up to 256 byte-pair-encoded (BPE) text tokens and the 32x32 = 1024 image tokens and learns to model their joint distribution.
The whole pipeline is trained by maximizing the evidence lower bound (ELBO) on the joint likelihood of the model distribution over images, captions, and the tokens for the encoded RGB image.
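To make the token budget concrete, here is a minimal sketch of the arithmetic and of how the two token streams could be laid out in one sequence. The image, grid, caption and codebook sizes come from the post; `TEXT_VOCAB`, `build_sequence` and the random dummy tokens are assumptions made for illustration, not the authors' code.

```python
import numpy as np

# Sizes quoted in the post.
IMAGE_SIDE = 256         # input image is 256x256 RGB
GRID_SIDE = 32           # the dVAE compresses it to a 32x32 grid of tokens
MAX_TEXT_TOKENS = 256    # BPE text tokens per caption
IMAGE_VOCAB = 8192       # size of the dVAE codebook
TEXT_VOCAB = 16384       # assumption: text vocabulary size, not stated in the post

pixel_values = IMAGE_SIDE * IMAGE_SIDE * 3         # values in the raw image
image_tokens = GRID_SIDE * GRID_SIDE               # tokens after the dVAE
compression = pixel_values // image_tokens         # context-size reduction factor
sequence_length = MAX_TEXT_TOKENS + image_tokens   # what the transformer sees


def build_sequence(text_ids, image_grid):
    """Concatenate caption tokens with the row-major flattened image grid."""
    text_ids = np.asarray(text_ids)
    image_ids = np.asarray(image_grid).reshape(-1)
    assert text_ids.shape == (MAX_TEXT_TOKENS,)
    assert image_ids.shape == (image_tokens,)
    return np.concatenate([text_ids, image_ids])


# Dummy tokens standing in for a real caption and a real encoded image.
rng = np.random.default_rng(0)
dummy_text = rng.integers(0, TEXT_VOCAB, size=MAX_TEXT_TOKENS)
dummy_grid = rng.integers(0, IMAGE_VOCAB, size=(GRID_SIDE, GRID_SIDE))
sequence = build_sequence(dummy_text, dummy_grid)

print(pixel_values, image_tokens, compression, sequence_length)
```

The text tokens come first, so with a causal mask every image token can see the full caption while it is being generated.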
2) dVAE - learning the visual codebook: In the first stage of the training pipeline the dVAE (a ResNet encoder-decoder with 1x1 convolutions before and after the quantized latent codes) is trained on images alone. The discrete codebook contains 8192 vectors. Because the distribution over codes is discrete, picking a code means sampling or taking an argmax, and gradients cannot flow through that step. Hence, the authors use the Gumbel-softmax relaxation: Gumbel noise is added to the encoder logits and the result is passed through a softmax with a temperature parameter. At high temperature the samples are smooth mixtures of codebook entries, and as the temperature approaches zero they approach one-hot vectors, i.e. an argmax. Annealing the temperature during training helps close the gap between the relaxed samples used in training and the true argmax used at inference time.
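A minimal NumPy sketch of Gumbel-softmax sampling may help here. It only illustrates how the temperature changes the sample; the five-entry `logits` vector and the two temperatures are made-up illustration values, and a real dVAE would do this in an autodiff framework over all 8192 codebook entries at every grid position.

```python
import numpy as np


def gumbel_softmax(logits, temperature, rng):
    """Draw one relaxed (soft) one-hot sample from categorical logits."""
    # Gumbel(0, 1) noise via the inverse CDF: -log(-log(U)), U ~ Uniform(0, 1).
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))
    # Perturb the logits, then apply a temperature-scaled softmax.
    y = (logits + gumbel) / temperature
    y = y - y.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(y)
    return e / e.sum(axis=-1, keepdims=True)


# Toy logits over a 5-entry "codebook" (illustration only).
logits = np.array([2.0, 1.0, 0.5, 0.0, -1.0])

# Same seed for both calls, so both samples share the same Gumbel noise
# and differ only in temperature.
soft = gumbel_softmax(logits, temperature=1.0, rng=np.random.default_rng(0))
sharp = gumbel_softmax(logits, temperature=1.0 / 16, rng=np.random.default_rng(0))

print(np.round(soft, 3))   # spread over several entries
print(np.round(sharp, 3))  # much closer to a one-hot vector
```

Both samples pick the same winning entry because the noise is shared; lowering the temperature only makes the vector more peaked, which is the behaviour the annealing schedule relies on.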
3) Transformer - learning the prior: At this point the dVAE is fixed, and the transformer is trained on image tokens obtained by taking the argmax over the dVAE encoder's logits. The transformer is a decoder-only model that uses different kinds of attention masks: causal for text-to-text, and either row, column or convolutional for image-to-image. Each image token can attend to all text tokens in any of its 64 self-attention layers. Since the number of text tokens is limited to 256, the authors learn a special padding token for each text position, which is used when the caption is too short to fill that position.
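The per-position padding idea can be sketched as follows. The ID layout (one extra ID per position, placed right after the regular text vocabulary), the value of `TEXT_VOCAB` and the `pad_caption` helper are assumptions chosen for illustration; the post does not say how the padding tokens are indexed.

```python
import numpy as np

MAX_TEXT_TOKENS = 256   # caption length limit from the post
TEXT_VOCAB = 16384      # assumption: size of the regular BPE vocabulary


def pad_caption(token_ids):
    """Pad a caption to MAX_TEXT_TOKENS with a distinct pad ID per position."""
    ids = list(token_ids)[:MAX_TEXT_TOKENS]   # truncate overly long captions
    padded = np.empty(MAX_TEXT_TOKENS, dtype=np.int64)
    padded[:len(ids)] = ids
    # Hypothetical layout: position i uses pad ID TEXT_VOCAB + i, so each
    # position gets its own row in the embedding table and can be learned
    # independently.
    empty_positions = np.arange(len(ids), MAX_TEXT_TOKENS)
    padded[len(ids):] = TEXT_VOCAB + empty_positions
    return padded


caption = [5, 9, 2]     # made-up BPE IDs for a three-token caption
padded = pad_caption(caption)
print(padded[:6])       # three real tokens followed by position-specific pads
```

With this layout the embedding table would need `TEXT_VOCAB + MAX_TEXT_TOKENS` rows for the text part of the sequence.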
4) Generation: At inference time the authors rerank the samples drawn from the model using a pretrained contrastive model (CLIP) that scores how well a caption matches a candidate image. They sample 512 images, rerank them according to the predicted scores, and select the top 32.
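The reranking step itself is just a sort by score. In the sketch below the random `scores` array is a placeholder for real CLIP caption-image similarities, and `rerank` is a hypothetical helper, not a function from the DALL-E or CLIP code.

```python
import numpy as np


def rerank(candidates, scores, top_k):
    """Return the top_k candidates, best score first, with their scores."""
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores)[:top_k]   # indices sorted by descending score
    return [candidates[i] for i in order], scores[order]


NUM_SAMPLES = 512   # images sampled per caption (from the post)
TOP_K = 32          # images kept after reranking (from the post)

# Placeholders: candidate names instead of images, random numbers instead of
# CLIP similarity scores.
rng = np.random.default_rng(0)
candidates = [f"sample_{i}" for i in range(NUM_SAMPLES)]
scores = rng.normal(size=NUM_SAMPLES)

best, best_scores = rerank(candidates, scores, TOP_K)
print(len(best), best[0])
```

Sampling many candidates and keeping only the best-scoring ones trades extra compute at inference time for better caption-image agreement.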
5) Engineering tricks: A good portion of the paper is dedicated to insights on how to train very large models on many GPUs (64 NVIDIA V100s with 16 GB each), and since it is unlikely that many people have access to such a setup, I won’t cover those tips and tricks here.
📈 Experiment insights / Key takeaways:
- DALL-E obtains FID within 2 points of the best prior approach on MS-COCO in a zero-shot setting
- In a human evaluation against DF-GAN, DALL-E’s sample was judged the better match for the text prompt about 93% of the time and the more realistic image about 90% of the time
- For very specific datasets like CUB the zero-shot approach is far from ideal, and most likely requires fine-tuning
- Not only is DALL-E able to handle some bizarre prompts without breaking down, it can also create complex scenes with a high level of abstraction, e.g. “an illustration of a baby hedgehog in a Christmas sweater walking a dog”.
- DALL-E has some capability to do zero-shot image-to-image translation, e.g. changing the color of an image, flipping it around, doing style transfer when passed the tokens for the top half of an image grid (the top half contains the image, the bottom is empty) as an input along with a text prompt “the exact same cat on the top as a sketch at the bottom”
🖼️ Paper Poster:
- DALL-E is a pretty dank name (5/5), but why is the model even called that? It isn’t mentioned once in the whole paper except for the link to the code 😂
- I am a bit conflicted about this paper. On the one hand the results are some of the most exciting that I have seen in a while, yet the method doesn’t have any big wow ideas, and the main focus seems to be on the amount of data and engineering tricks to make the model train on a ton of GPUs in parallel
DALL-E arxiv / DALL-E github
👋 Thanks for reading!
If you found this paper digest useful, subscribe and share the post with your friends and colleagues to support Casual GAN Papers!
P.S. Send me paper suggestions for future posts @KirillDemochkin!