
The Gumbel-Softmax Trick: Making Discrete Sampling Differentiable

A problem that haunted generative modeling for years: how do you backpropagate through a discrete sampling operation?

This post explores how a simple reparameterization enables gradients to flow through categorical sampling, unlocking end-to-end training with discrete latent variables.

In models like DALL-E [1], the goal is to maximize the joint likelihood $\ln p_{\theta,\psi}(x,y)$ of images and text. Using the standard ELBO derivation, we get a familiar-looking lower bound:

$$\ln p_{\theta,\psi}(x,y) \geq \mathbb{E}_{z \sim q_{\phi}(z|x)}\left[\ln p_{\theta}(x|y,z)\right] - \beta\, D_{KL}\!\left(q_{\phi}(y,z|x) \,\|\, p_{\psi}(y,z)\right)$$

Here’s the twist: unlike a VAE’s continuous latent space, $z$ here is sampled from a categorical distribution $q_{\phi}(z|x)$. It’s discrete. And that breaks everything. The sampling operation creates a hard stop in the computational graph: no gradients can flow back through it. So while we could compute the ELBO, we couldn’t optimize it end-to-end.
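To make that hard stop concrete, here is a minimal PyTorch sketch of the problem (the shapes and variable names are illustrative, not DALL-E’s actual code):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, requires_grad=True)     # encoder output over 4 categories
probs = F.softmax(logits, dim=-1)

# Drawing a hard categorical sample severs the computational graph:
idx = torch.multinomial(probs, num_samples=1)   # integer index, no grad_fn
z = F.one_hot(idx, num_classes=4).float()       # one-hot latent code
print(z.requires_grad)                          # False: no gradient path back to `logits`
```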

The solution? A clever reparameterization trick, but for categoricals instead of Gaussians. That’s where the Gumbel-Softmax comes in.

The Core Idea: Continuous Relaxation

The key insight comes from Maddison et al.’s Concrete distribution [2]. Instead of sampling a one-hot vector directly, we sample from a continuous distribution over the simplex $\Delta^{k-1}$ (the set of $k$-dimensional probability vectors). Think of it as sampling near the vertices of a simplex, where each vertex represents a pure category.

The process is surprisingly elegant. Given unnormalized class probabilities $\pi_1, \dots, \pi_k$:

  1. Add Gumbel noise: Sample $g_i \sim \text{Gumbel}(0,1)$ and add it to each $\log \pi_i$. This creates perturbed logits that are stochastic but still differentiable with respect to $\pi$.
  2. Softmax with temperature: Apply a tempered softmax to get a “soft” one-hot vector (a code sketch follows the list):
$$X_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j=1}^k \exp\big((\log \pi_j + g_j)/\tau\big)}$$
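Here is a minimal sketch of that sampler in PyTorch, written out by hand so the two steps are visible (the function name is mine; PyTorch also ships torch.nn.functional.gumbel_softmax, which does the same thing):

```python
import torch

def gumbel_softmax_sample(log_pi: torch.Tensor, tau: float) -> torch.Tensor:
    """Draw one soft sample on the simplex from logits log_pi of shape [..., k]."""
    # Step 1: Gumbel(0,1) noise via the inverse-CDF trick g = -log(-log(U)), U ~ Uniform(0,1)
    u = torch.rand_like(log_pi).clamp_min(1e-20)
    g = -torch.log(-torch.log(u))
    # Step 2: tempered softmax over the perturbed logits
    return torch.softmax((log_pi + g) / tau, dim=-1)
```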

The temperature $\tau > 0$ controls how sharply the distribution peaks. As $\tau \to 0$, the samples approach true one-hot vectors; as $\tau$ grows, they become more uniform. During training, you typically anneal $\tau$ from high to low, trading higher gradient variance for samples that are closer to genuinely discrete.
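Continuing with the sampler above, a quick check of the temperature’s effect, plus one common annealing schedule (exponential decay toward a floor; the constants are placeholders, not canonical values):

```python
import math

log_pi = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4]))

for tau in (5.0, 1.0, 0.1):
    print(tau, gumbel_softmax_sample(log_pi, tau))  # high tau: near-uniform; low tau: near one-hot

# Anneal tau over training steps: start smooth, end close to discrete.
def tau_at_step(t: int, tau_0: float = 1.0, rate: float = 1e-4, tau_min: float = 0.5) -> float:
    return max(tau_min, tau_0 * math.exp(-rate * t))
```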

The resulting Concrete distribution over the simplex has this joint density:

$$p_{\pi,\tau}(X) = (k-1)! \, \tau^{k-1} \prod_{i=1}^k \left(\frac{\pi_i X_i^{-\tau-1}}{\sum_{j=1}^k \pi_j X_j^{-\tau}}\right)$$

The formula is less important than what it does: it gives us a way to sample from a categorical distribution while keeping the gradients intact.
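If you ever do need the density, say to estimate a KL term by Monte Carlo, it is easiest to evaluate in log space; here is a sketch in the same notation (the function name and the Monte Carlo use case are mine, not part of the DALL-E setup):

```python
import math
import torch

def concrete_log_density(x: torch.Tensor, pi: torch.Tensor, tau: float) -> torch.Tensor:
    """log p_{pi,tau}(x) for x in the interior of the simplex; x and pi have shape [k]."""
    k = x.shape[-1]
    # log (k-1)!  +  (k-1) log tau
    log_norm = torch.lgamma(torch.tensor(float(k))) + (k - 1) * math.log(tau)
    # sum_i [log pi_i - (tau + 1) log x_i]  -  k * log sum_j pi_j x_j^{-tau}
    per_coord = (torch.log(pi) - (tau + 1) * torch.log(x)).sum(dim=-1)
    log_denom = torch.logsumexp(torch.log(pi) - tau * torch.log(x), dim=-1)
    return log_norm + per_coord - k * log_denom
```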

In practice, we often use the Gumbel–Softmax (Concrete) relaxation during training to obtain differentiable, soft one-hot samples (with a temperature schedule), then discretize via argmax at inference. Alternatively, a straight-through estimator (hard forward pass, soft backward pass) can be used during training. This small hack enables many models with discrete choices: VAEs with discrete latents, hard-attention models, and so on. That said, large text-to-image systems sometimes rely on discrete codebooks or other techniques instead of Gumbel relaxations.
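For the straight-through variant mentioned above, a minimal sketch reusing gumbel_softmax_sample from earlier (PyTorch’s built-in torch.nn.functional.gumbel_softmax(logits, tau, hard=True) implements the same estimator):

```python
def gumbel_softmax_straight_through(log_pi: torch.Tensor, tau: float) -> torch.Tensor:
    x_soft = gumbel_softmax_sample(log_pi, tau)                   # differentiable, on the simplex
    index = x_soft.argmax(dim=-1, keepdim=True)
    x_hard = torch.zeros_like(x_soft).scatter_(-1, index, 1.0)    # exact one-hot, no grad path
    # Forward pass returns the hard one-hot; backward pass uses the soft sample's gradient.
    return (x_hard - x_soft).detach() + x_soft
```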

Footnotes

  1. Ramesh, A., et al. (2021). Zero-Shot Text-to-Image Generation. OpenAI.

  2. Maddison, C. J., et al. (2017). The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables.