Gumbel Softmax


Problem

In DALL-E, the authors want to maximize $\ln p_{\theta,\psi}(x, y)$. Using the evidence lower bound (ELBO), they derive:

$$\ln p_{\theta,\psi}(x, y) \geq \mathbb{E}_{z \sim q_{\phi}(z|x)}\left[\ln p_{\theta}(x|y, z)\right] - \beta D_{KL}\left(q_{\phi}(y, z|x) \,\|\, p_{\psi}(y, z)\right)$$

Here $z$ is a sample from the categorical distribution $q_{\phi}(z|x)$, so it is a discrete variable. This is analogous to a VAE, where $z$ is a latent variable, except that a VAE's latent space is continuous while DALL-E's is discrete. The problem is that sampling from a categorical distribution is not differentiable, so we cannot backpropagate through it. The Gaussian reparameterization trick used in VAEs does not apply directly to a discrete variable, so the authors use its categorical counterpart, the Gumbel Softmax trick. This defines our problem: we want to sample from a $k$-dimensional categorical distribution with unnormalized probabilities $[\pi_1, \ldots, \pi_k]$ while allowing gradients to flow through the sample.
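To make the non-differentiability concrete, here is a minimal sketch in plain Python. The helper names (`sample_index`, `finite_difference`) and the example values are illustrative assumptions, not anything from the DALL-E paper. With the random draw `u` held fixed, the sampled index is a piecewise-constant function of the weights, so a small change in a weight almost never changes the sample and the derivative is zero almost everywhere.

```python
def sample_index(pi, u):
    """Inverse-CDF categorical sample from unnormalized weights pi,
    using a fixed uniform draw u in [0, 1)."""
    total = sum(pi)
    cumulative = 0.0
    for i, p in enumerate(pi):
        cumulative += p / total
        if u < cumulative:
            return i
    return len(pi) - 1  # guard against floating-point round-off


def finite_difference(pi, u, which=0, eps=1e-4):
    """Finite-difference estimate of d(sample index) / d(pi[which])."""
    bumped = list(pi)
    bumped[which] += eps
    return (sample_index(bumped, u) - sample_index(pi, u)) / eps


pi = [2.0, 5.0, 3.0]  # hypothetical unnormalized probabilities
u = 0.42              # the "randomness", held fixed

print(sample_index(pi, u))
print(finite_difference(pi, u))
```

The estimate is zero unless `u` happens to sit within `eps` of a boundary between categories, in which case the sample jumps. Either way, a gradient-based optimizer gets no useful signal about how to change `pi`.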

Gumbel Softmax

The Gumbel Softmax relies heavily on the results of Maddison et al., who introduce the Concrete distribution over $X \in \Delta^{k-1}$, where $\Delta^{k-1}$ is the $(k-1)$-simplex: the set of all $k$-dimensional vectors with non-negative entries that sum to 1. The Concrete distribution is a continuous relaxation of the categorical distribution. Intuitively, a categorical sample is a vertex of this simplex (a one-hot vector), and the Concrete distribution relaxes it to a point inside the simplex. The authors define a sample from the Concrete distribution as:

$$X_i = \dfrac{\exp\left((\log \pi_i + g_i) / \tau\right)}{\sum_{j=1}^k \exp\left((\log \pi_j + g_j) / \tau\right)}$$

where the $g_i \sim \mathrm{Gumbel}(0,1)$ are independent and the temperature satisfies $\tau > 0$. Based on this, the density over the simplex is:

$$p_{\pi, \tau}(X) = (k-1)!\, \tau^{k-1} \prod_{i=1}^k \left(\dfrac{\pi_i X_i^{-\tau-1}}{\sum_{j=1}^k \pi_j X_j^{-\tau}}\right)$$
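The sampling rule and the density can be sketched in plain Python as follows. This is an illustrative sketch, not the DALL-E or Maddison et al. implementation. The function names are assumptions, and the sampler divides the whole perturbed logit $(\log \pi_i + g_i)$ by $\tau$ before the softmax, following the definition in Maddison et al. Gumbel noise is drawn by the inverse-CDF method, $g = -\log(-\log u)$ with $u \sim \mathrm{Uniform}(0,1)$.

```python
import math
import random


def sample_gumbel(rng):
    """Draw g ~ Gumbel(0, 1) via the inverse CDF: g = -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # log(0) is undefined; redraw in that rare case
        u = rng.random()
    return -math.log(-math.log(u))


def sample_concrete(pi, tau, rng):
    """Draw X on the simplex as softmax((log pi_i + g_i) / tau)."""
    scores = [(math.log(p) + sample_gumbel(rng)) / tau for p in pi]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def concrete_log_density(x, pi, tau):
    """Log of the Concrete density at a point x with strictly positive entries."""
    k = len(pi)
    log_norm = math.log(sum(p * xi ** (-tau) for p, xi in zip(pi, x)))
    log_prod = sum(math.log(p) - (tau + 1.0) * math.log(xi)
                   for p, xi in zip(pi, x))
    # lgamma(k) equals log((k - 1)!)
    return math.lgamma(k) + (k - 1) * math.log(tau) + log_prod - k * log_norm


rng = random.Random(0)
pi = [2.0, 5.0, 3.0]  # hypothetical unnormalized probabilities
for tau in (5.0, 1.0, 0.1):
    x = sample_concrete(pi, tau, rng)
    print(tau, [round(v, 3) for v in x])
```

Every operation in `sample_concrete` is a smooth function of `pi` once the noise is fixed, which is what lets gradients flow. As $\tau$ shrinks, samples concentrate near the vertices of the simplex. For any $\tau$, the largest coordinate lands on index $i$ with probability $\pi_i / \sum_j \pi_j$. As a sanity check on the density, with $k = 2$, uniform $\pi$ and $\tau = 1$ it evaluates to 1 everywhere, i.e. the uniform distribution on the segment.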

Footnotes

  1. Aditya Ramesh et al. Zero-Shot Text-to-Image Generation. OpenAI, 2021.
  2. Chris J. Maddison et al. The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables. ICLR 2017.