Gumbel Softmax

Intro

I have seen Gumbel Softmax (Jang et al.) pop up multiple times, most recently in the original DALLE paper. It's a neat idea that solves what seems like a trivial problem. The approach itself, however, is far from trivial, and it's another example of cool math being cleverly applied.

Problem Description

Let's look at where Gumbel is used in DALLE to motivate the problem. The authors want to maximize $\ln p_{\theta,\psi}(x,y)$, and via the ELBO they derive the lower bound:

$$\ln p_{\theta,\psi}(x,y) \geq \mathbb{E}_{z \sim q_\phi(z|x)}\left[\ln p_\theta(x|y,z) - \beta\, D_{KL}(q_\phi(y,z|x),\, p_\psi(y,z))\right]$$

Here $z$ is a sample from the categorical distribution defined by $q_\phi(z|x)$. This is analogous to a VAE, where $z$ is a latent, but while a VAE's latent space is continuous, DALLE's is discrete (the codebook). Therefore, we can't directly apply the reparameterization trick used in VAEs. This defines our problem: we want to sample from a $k$-dimensional categorical distribution with unnormalized probabilities $[\pi_1, \ldots, \pi_k]$ while still letting gradients flow through.
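For contrast, here is a minimal sketch of the continuous-case reparameterization trick that does work in a VAE (the values of mu and sigma here are hypothetical stand-ins for encoder outputs):

import numpy as np

# Gaussian reparameterization: a sample z ~ N(mu, sigma^2) is rewritten as a
# deterministic, differentiable function of (mu, sigma) plus parameter-free noise.
mu, sigma = 0.5, 2.0        # hypothetical encoder outputs
eps = np.random.randn()     # eps ~ N(0, 1) carries all the randomness
z = mu + sigma * eps        # gradients can flow through mu and sigma

# The discrete analogue fails: np.random.choice returns an integer index,
# and an index has no gradient with respect to the probabilities.
idx = np.random.choice(3, p=[0.1, 0.3, 0.6])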

Gumbel Softmax

The Gumbel Softmax heavily utilizes the results of (Maddison et al.), where the authors present the Concrete distribution: $X \in \Delta^{k-1}$, where $\Delta^{k-1}$ is the $(k-1)$-dimensional simplex. Intuitively, we want to sample from the vertices of this simplex according to our categorical distribution. The authors define a Concrete sample as $X_i = \frac{\exp\left((\log \pi_i + g_i)/\tau\right)}{\sum_{j=1}^{k}\exp\left((\log \pi_j + g_j)/\tau\right)}$ where $g_i \sim \text{Gumbel}(0,1)$. We will discuss what the temperature $\tau$ controls later. Based on this, the joint density over the simplex is:

$$p_{\pi,\tau}(X) = (k-1)!\,\tau^{k-1}\prod_{i=1}^{k}\left(\frac{\pi_i X_i^{-\tau-1}}{\sum_{j=1}^{k}\pi_j X_j^{-\tau}}\right)$$
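One small detail worth spelling out, since the code below relies on it: $\text{Gumbel}(0,1)$ noise is easy to draw by inverse-transform sampling. The CDF is $F(g) = \exp(-\exp(-g))$, so setting $F(g) = u$ for $u \sim \text{Uniform}(0,1)$ and inverting gives

$$g = -\log(-\log(u))$$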

Let's say we have a simple 3-dimensional categorical distribution with unnormalized probabilities $[1, 3, 6]$.

Plotting samples on the simplex for various values of $\tau$ (as the code below does), we can see from the equations and the plots that $\tau$ controls whether the samples are sparse or close to uniform. As $\tau \rightarrow 0$, the samples concentrate on the vertices and the distribution approaches a categorical with probabilities proportional to $\pi_i$. Therefore, if we want to sample from a categorical, all we need to do is sample $k$ Gumbel random variables $g_i$ and construct a sample $X$ using the equation for $X_i$ above. The gradients do not flow through the $g_i$; they flow only through the $\pi_i$, which are the outputs of a previous function (e.g., an encoder network). Thus we are all good to differentiate!
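As a quick sanity check on that limit (a sketch, not taken from either paper's code), the $\tau \rightarrow 0$ behavior is exactly the Gumbel-max trick, taking an argmax instead of a softmax:

import numpy as np

# Gumbel-max trick: argmax_i (log(pi_i) + g_i) is a sample from
# Categorical(pi / sum(pi)). This is the tau -> 0 limit of Gumbel Softmax.
pi = np.array([1.0, 3.0, 6.0])
n = 100000
u = np.random.uniform(0, 1, size=(n, len(pi)))
g = -np.log(-np.log(u))                      # Gumbel(0, 1) noise
idx = np.argmax(np.log(pi) + g, axis=1)
print(np.bincount(idx) / n)                  # empirically close to [0.1, 0.3, 0.6]
print(pi / pi.sum())                         # exact: [0.1, 0.3, 0.6]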

Code

import numpy as np
import plotly.express as px

# Gumbel-Softmax sampler
def sample(pi, t):
    """
    pi: vector of unnormalized probabilities
    t: temperature tau
    """
    # Draw Gumbel(0, 1) noise by inverse-transform sampling
    u = np.random.uniform(0, 1, len(pi))
    g = -np.log(-np.log(u))
    # Softmax of (log(pi) + g) / t gives a point on the simplex
    logits = (np.log(pi) + g) / t
    return np.exp(logits) / np.sum(np.exp(logits))

t = 3
n = 10000
pi = np.array([1, 3, 6])

# Draw n samples; each row is one point on the 2-simplex
samples = np.array([sample(pi, t) for _ in range(n)])

# Ternary scatter plot of the sampled points
fig = px.scatter_ternary(a=samples[:, 0], b=samples[:, 1], c=samples[:, 2], opacity=0.1)
fig.show()
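To close the loop on the "gradients flow" claim, here is a minimal sketch of the same sampler in PyTorch (the logits are a stand-in for an encoder's output, and the loss is a stand-in for a real objective); PyTorch also ships a built-in version as torch.nn.functional.gumbel_softmax:

import torch
import torch.nn.functional as F

# Same construction, but with autograd: gradients should reach the logits.
logits = torch.log(torch.tensor([1.0, 3.0, 6.0])).requires_grad_()

u = torch.rand(3)
g = -torch.log(-torch.log(u))              # Gumbel(0, 1) noise, no parameters
x = F.softmax((logits + g) / 0.5, dim=-1)  # soft sample with tau = 0.5

x[0].backward()                            # stand-in for a downstream loss
print(logits.grad)                         # nonzero: gradients flow through pi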