A problem that haunted generative modeling for years: how do you backpropagate through a discrete sampling operation?
This post explores how a simple reparameterization enables gradients to flow through categorical sampling, unlocking end-to-end training with discrete latent variables.
In models like DALL-E [1], the goal is to maximize the joint likelihood of images and text. Using the standard ELBO derivation (with $x$ the image, $y$ the caption, and $z$ the discrete image tokens), we get a familiar-looking lower bound:

$$
\ln p_{\theta,\psi}(x, y) \;\ge\; \mathbb{E}_{z \sim q_\phi(z \mid x)}\big[\ln p_\theta(x \mid y, z)\big] \;-\; D_{\mathrm{KL}}\big(q_\phi(y, z \mid x)\,\|\,p_\psi(y, z)\big),
$$

where $q_\phi$ is the encoder's distribution over image tokens, $p_\theta$ the image decoder, and $p_\psi$ the prior over text and image tokens.
Here’s the twist: unlike a VAE’s continuous latent space, $z$ here is sampled from a categorical distribution $q_\phi(z \mid x)$. It’s discrete. And that breaks everything. The sampling operation creates a hard stop in the computational graph: no gradients can flow back through it. So while we could compute the ELBO, we couldn’t optimize it end-to-end.
The solution? A clever reparameterization trick, but for categoricals instead of Gaussians. That’s where the Gumbel-Softmax comes in.
The Core Idea: Continuous Relaxation
The key insight comes from Maddison et al.’s Concrete distribution [2]. Instead of sampling a one-hot vector directly, we sample from a continuous distribution over the simplex (the set of $K$-dimensional probability vectors). Think of it as sampling near the vertices of the simplex, where each vertex represents a pure category.
The process is surprisingly elegant. Given unnormalized logits $\ell_1, \dots, \ell_K$:
- Add Gumbel noise: Sample $g_i \sim \mathrm{Gumbel}(0, 1)$ and add it to each $\ell_i$. The perturbed logits $\ell_i + g_i$ are stochastic, but because the noise enters additively, they remain differentiable with respect to $\ell_i$.
- Softmax with temperature: Apply a tempered softmax to get a “soft” one-hot vector:

$$
y_i = \frac{\exp\big((\ell_i + g_i)/\tau\big)}{\sum_{j=1}^{K} \exp\big((\ell_j + g_j)/\tau\big)}.
$$
The temperature $\tau$ controls how sharply the distribution peaks. As $\tau \to 0$, the samples approach true one-hot vectors; as $\tau$ grows, they become more uniform. During training, you typically anneal $\tau$ from high to low, accepting higher gradient variance in exchange for samples that are closer to genuinely discrete one-hot vectors.
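To make the recipe concrete, here is a minimal sketch in PyTorch (the function names and toy logits are mine, not from any particular codebase): it draws Gumbel noise with the inverse-CDF trick $g = -\log(-\log u)$, $u \sim \mathrm{Uniform}(0, 1)$, then pushes the perturbed logits through the tempered softmax.

```python
import torch

def sample_gumbel(shape, eps=1e-20):
    """Standard Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1)."""
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)

def gumbel_softmax_sample(logits, tau=1.0):
    """'Soft' one-hot sample: softmax((logits + Gumbel noise) / tau)."""
    g = sample_gumbel(logits.shape)
    return torch.softmax((logits + g) / tau, dim=-1)

# Toy example with K = 4 categories: lower temperature -> sample closer to one-hot.
logits = torch.tensor([1.0, 2.0, 0.5, -1.0], requires_grad=True)
for tau in (5.0, 1.0, 0.1):
    print(tau, gumbel_softmax_sample(logits, tau).detach().numpy().round(3))

# Gradients flow through the soft sample back to the logits.
y = gumbel_softmax_sample(logits, tau=1.0)
loss = (y * torch.arange(4.0)).sum()   # arbitrary differentiable objective on the sample
loss.backward()
print(logits.grad)
```

(PyTorch also ships this as `torch.nn.functional.gumbel_softmax`, including a hard straight-through mode.)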
The resulting Concrete distribution over the simplex has this joint density:

$$
p_{\pi,\tau}(y_1, \dots, y_K) = (K-1)!\,\tau^{K-1} \prod_{i=1}^{K} \frac{\pi_i\, y_i^{-\tau-1}}{\sum_{j=1}^{K} \pi_j\, y_j^{-\tau}},
$$

where $\pi_i \propto \exp(\ell_i)$ are the class probabilities.
The formula is less important than what it does: it gives us a way to sample from a categorical distribution while keeping the gradients intact.
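If you do want to evaluate it, here is a small sketch (the helper name is mine) that computes the log of the density above for a point $y$ in the interior of the simplex; note that every term is an ordinary differentiable operation.

```python
import torch

def concrete_log_density(probs, y, tau):
    """log p(y) for the Concrete(probs, tau) density above:
    log (K-1)! + (K-1) log tau + sum_i [log pi_i - (tau+1) log y_i]
                               - K * log sum_j pi_j * y_j^(-tau)
    """
    K = probs.shape[-1]
    log_factorial = torch.lgamma(torch.tensor(float(K)))               # log (K-1)!
    per_class = (probs.log() - (tau + 1.0) * y.log()).sum(dim=-1)
    log_denominator = torch.logsumexp(probs.log() - tau * y.log(), dim=-1)
    return log_factorial + (K - 1) * torch.log(torch.tensor(tau)) + per_class - K * log_denominator

# Example: class probabilities from softmaxed logits, and a relaxed sample y on the simplex.
probs = torch.softmax(torch.tensor([1.0, 2.0, 0.5, -1.0]), dim=-1)
y = torch.tensor([0.05, 0.85, 0.07, 0.03])
print(concrete_log_density(probs, y, tau=0.5))
```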
In practice, we often use the Gumbel-Softmax (Concrete) relaxation during training to obtain differentiable, soft one-hot samples (with a temperature schedule), then discretize with an argmax at inference time. Alternatively, a straight-through estimator (hard forward pass, soft backward pass) can be used during training. This small hack enables many models with discrete choices: VAEs with discrete latents, hard-attention models, and so on. That said, large text-to-image systems sometimes rely on discrete codebooks or other techniques instead of Gumbel relaxations.
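Here is what the straight-through variant looks like as a sketch (again my own minimal version; PyTorch's built-in `torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)` implements the same behavior): the forward pass emits an exact one-hot vector, while gradients in the backward pass come from the soft sample.

```python
import torch

def straight_through_gumbel_softmax(logits, tau=1.0, eps=1e-20):
    """Hard one-hot sample in the forward pass, soft Gumbel-Softmax gradient in the backward pass."""
    g = -torch.log(-torch.log(torch.rand_like(logits) + eps) + eps)   # Gumbel(0, 1) noise
    y_soft = torch.softmax((logits + g) / tau, dim=-1)                # relaxed sample
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)        # exact one-hot of the argmax
    # detach() makes the forward value equal y_hard while gradients flow through y_soft.
    return (y_hard - y_soft).detach() + y_soft
```

Downstream layers then see genuinely discrete codes during training, at the cost of a biased gradient estimate.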