Gumbel-Softmax Trick
Thu 30 August 2018
Gumbel Distribution
cdf:
$$F(x;\mu)=e^{-e^{-(x-\mu)}}$$
pdf:
$$f(x;\mu)=e^{-(x-\mu)-e^{-(x-\mu)}}$$
mean:
$$E[X]=\mu+\gamma$$
where \(\gamma\approx0.5772\) is the Euler-Mascheroni constant.
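As a quick numerical check of the mean (a minimal NumPy sketch; the value of \(\mu\) below is an arbitrary choice), we can draw samples and compare the empirical mean against \(\mu+\gamma\):

```python
import numpy as np

mu = 1.5                              # arbitrary location parameter for the check
gamma = 0.5772156649                  # Euler-Mascheroni constant

samples = np.random.gumbel(loc=mu, scale=1.0, size=1_000_000)
print(samples.mean(), mu + gamma)     # the two values should be close (~2.077)
```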
Sanity Check
- Check that \(\lim_{x\to-\infty}F(x;\mu)=0\) and \(\lim_{x\to\infty}F(x;\mu)=1\)
- Check that \(\frac{\partial F(x;\mu)}{\partial x}=f(x;\mu)\)
$$\begin{aligned} \frac{\partial F(x;\mu)}{\partial x} & =e^{-e^{-(x-\mu)}}\left(-\frac{\partial e^{-(x-\mu)}}{\partial x}\right)\\ & =e^{-e^{-(x-\mu)}}e^{-(x-\mu)}\\ & =e^{-(x-\mu)-e^{-(x-\mu)}}\\ & =f(x;\mu)\end{aligned}$$
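The derivative check can also be done numerically; here is a minimal NumPy sketch comparing a central finite difference of \(F\) against \(f\) at an arbitrarily chosen point:

```python
import numpy as np

def F(x, mu):
    """cdf of Gumbel(mu): exp(-exp(-(x - mu)))."""
    return np.exp(-np.exp(-(x - mu)))

def f(x, mu):
    """pdf of Gumbel(mu): exp(-(x - mu) - exp(-(x - mu)))."""
    return np.exp(-(x - mu) - np.exp(-(x - mu)))

x, mu, h = 0.3, 1.0, 1e-6
finite_diff = (F(x + h, mu) - F(x - h, mu)) / (2 * h)   # central difference of the cdf
print(finite_diff, f(x, mu))                            # the two values should agree closely
```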
Sampling
Note that the inverse-cdf is given in closed form:
$$F^{-1}(u;\mu)=\mu-\log(-\log u),\quad u\in(0,1)$$
Therefore, we can sample \(X\sim\text{Gumbel}(\mu)\) by drawing \(U\sim\text{Uniform}(0,1)\) and setting \(X=\mu-\log(-\log U)\).
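In code, inverse-cdf sampling is a one-liner; a minimal NumPy sketch (the helper name `sample_gumbel` is my own):

```python
import numpy as np

def sample_gumbel(mu, size):
    """Inverse-cdf sampling: x = mu - log(-log(u)), u ~ Uniform(0, 1)."""
    u = np.random.uniform(size=size)
    return mu - np.log(-np.log(u))

x = sample_gumbel(mu=2.0, size=1_000_000)
print(x.mean())                       # ~ mu + 0.5772, matching the Gumbel mean above
```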
Translation
Note that the pdf of the Gumbel distribution depends on \(x\) only through \((x-\mu)\). Therefore, \(\text{Gumbel}(\mu)=\mu+\text{Gumbel}(0)\) in distribution, just as a normal random variable can be shifted by its mean.
Gumbel-Softmax Trick
Maximum of Multiple Gumbel Random Variables
Let \(X_{k}\sim\text{Gumbel}(\mu_{k})\) be independent Gumbel random variables with location parameters \(\mu_{k}\). The probability that the \(K\)-th random variable \(X_{K}\) is the maximum, given that \(X_{K}=x_{K}\), is:
$$P(X_{K}\ge X_{k},\ \forall k\ne K\mid X_{K}=x_{K})=\prod_{k\ne K}F(x_{K};\mu_{k})=\prod_{k\ne K}e^{-e^{-(x_{K}-\mu_{k})}}$$
Removing the condition of having a specific value,
$$\begin{aligned} P(X_{K}\ge X_{k},\ \forall k\ne K) & =\int_{-\infty}^{\infty}f(x;\mu_{K})\prod_{k\ne K}F(x;\mu_{k})\,dx\\ & =\int_{-\infty}^{\infty}e^{-(x-\mu_{K})}\exp\left(-\sum_{k}e^{-(x-\mu_{k})}\right)dx\\ & =\frac{\exp\mu_{K}}{\sum_{k}\exp\mu_{k}}=\exp\left(\mu_{K}-M\right),\end{aligned}$$
where \(M=\log\sum_{k}\exp\mu_{k}\).
So, the probability of the \(K\)-th Gumbel random variable being the maximum is given by the softmax function of the \(\mu_{k}\)'s.
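This can be verified empirically; the sketch below (NumPy, with arbitrarily chosen \(\mu_{k}\)'s) compares the frequency with which each \(X_{k}\) is the maximum against \(\text{softmax}(\mu)\):

```python
import numpy as np

mu = np.array([0.5, -1.0, 2.0, 0.0])             # arbitrary location parameters
n = 1_000_000

g = np.random.gumbel(size=(n, len(mu)))          # Gumbel(0) noise
counts = np.bincount(np.argmax(mu + g, axis=1), minlength=len(mu))

softmax = np.exp(mu) / np.exp(mu).sum()          # exp(mu_k - M)
print(counts / n)                                # empirical P(X_k is the maximum)
print(softmax)                                   # should match closely
```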
Sampling from Categorical Distribution using Gumbel-Max
Consider an \(m\)-class categorical random variable \(Y\sim\text{Categorical}(\pi_{1},\dots,\pi_{m})\). We can build the same random variable using \(m\) Gumbel random variables. In particular, let \(X_{k}\sim\text{Gumbel}(\log\pi_{k})=\log\pi_{k}+\text{Gumbel}(0)\) and define \(Y:=\arg\max_{k}X_{k}\). By the result above, \(P(Y=K)=\exp(\log\pi_{K}-M)\) with \(M=\log\sum_{k}\exp\log\pi_{k}=\log1=0\), so \(P(Y=K)=\pi_{K}\), i.e., the two constructions yield the same distribution.
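A minimal NumPy sketch of Gumbel-Max sampling from a categorical distribution (the \(\pi_{k}\)'s below are arbitrary):

```python
import numpy as np

pi = np.array([0.2, 0.5, 0.3])                   # categorical probabilities
n = 1_000_000

g = np.random.gumbel(size=(n, len(pi)))          # g_k ~ Gumbel(0)
y = np.argmax(np.log(pi) + g, axis=1)            # Y = argmax_k (log pi_k + g_k)

print(np.bincount(y, minlength=len(pi)) / n)     # empirical frequencies, ~ [0.2, 0.5, 0.3]
```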
Motivation of the Gumbel-Softmax Trick
But why do we need this equivalent Gumbel-Max representation given that the original Categorical distribution is simple enough? In many machine learning problems, we often approximate the probability density with Monte Carlo samples. Further, we often want to differentiate some term with respect to the probability density (or its parameters). But this becomes tricky if the density has been replaced by the Monte Carlo samples.
To illustrate, let us review the idea of the Stochastic Gradient Variational Bayes (SGVB) algorithm suggested by Kingma et al. In the variational Bayes problem, the evidence (variational) lower bound (ELBO) is given as:
$$\mathcal{L}(\phi)=E_{Z\sim q(z;\phi)}\left[\log p(x,Z)-\log q(Z;\phi)\right]\approx\frac{1}{J}\sum_{j=1}^{J}\left[\log p(x,z_{j})-\log q(z_{j};\phi)\right],\qquad z_{j}\sim q(z;\phi)$$
Here, computing the expectation over \(Z\sim q(z;\phi)\) is intractable and is thus replaced by the MC samples \(\{z_{j}\}\). The goal is to find the \(\phi\) that maximizes \(\mathcal{L}(\phi)\). But then, to compute \(\partial\mathcal{L}(\phi)/\partial\phi\) for gradient ascent, we run into the problem of differentiating through the MC samples.
SGVB suggests reparameterizing \(q(z;\phi)\) when possible. In the case of \(Z\sim q(z;\phi)=\mathcal{N}(z;\mu,\sigma^{2})\), it can be equivalently represented as \(Z=\mu+\sigma\epsilon\), \(\epsilon\sim\mathcal{N}(0,1)\). Now \(Z\) is no longer a primary random variable; rather, it is derived from the random variable \(\epsilon\).
Thus, we can safely compute \(\partial\mathcal{L}(\phi)/\partial\phi\) with \(\phi=(\mu,\sigma)\), since the MC samples are now samples of \(\epsilon\), which does not depend on \(\phi\). In other words, the sampling is separated out of the differentiation chain.
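A minimal PyTorch sketch of this reparameterization (the loss here is an arbitrary stand-in, not an actual ELBO): gradients with respect to \(\mu\) and \(\sigma\) flow through the samples because only \(\epsilon\) is sampled.

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)
log_sigma = torch.tensor(0.0, requires_grad=True)    # parameterize sigma > 0 through its log

eps = torch.randn(1000)                              # epsilon ~ N(0, 1), independent of (mu, sigma)
z = mu + torch.exp(log_sigma) * eps                  # reparameterized samples of Z

loss = (z ** 2).mean()                               # stand-in for an MC estimate of the objective
loss.backward()
print(mu.grad, log_sigma.grad)                       # gradients flow through the sampled z's
```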
Similarly, consider a set of categorical variables \(\{Y_{i}\}^{D}\), \(Y_{i}\sim\text{Categorical}(\pi_{1}^{(i)},\dots,\pi_{m}^{(i)})\), and a function \(f(Y_{1},\dots,Y_{D})\). The goal is to find the parameters \(\{\pi_{1}^{(i)},\dots,\pi_{m}^{(i)}\}^{D}\) that minimize the loss \(\mathcal{L}=E_{Y_{1},\dots,Y_{D}}[f(Y_{1},\dots,Y_{D})]\). Computing this expectation is intractable for large \(D\), so it is again replaced by MC samples, and gradient descent runs into the same problem of differentiating through MC samples. But consider the Gumbel-Max representation \(Y_{i}=\arg\max_{k}X_{k}^{(i)}=\arg\max_{k}(\log\pi_{k}^{(i)}+g_{k}^{(i)})\), \(g_{k}^{(i)}\sim\text{Gumbel}(0)\). Then the \(Y_{i}\)'s are no longer primary random variables; rather, they are derived from the random variables \(g_{k}^{(i)}\). Thus, we can safely compute \(\partial\mathcal{L}/\partial\pi_{k}^{(i)}\), except for the \(\arg\max\) part.
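The remaining obstacle is exactly that \(\arg\max\) part: it is piecewise constant, so no useful gradient reaches the \(\pi_{k}^{(i)}\)'s. A minimal PyTorch sketch of the issue (arbitrary \(\pi\), single variable):

```python
import torch

pi = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)

g = -torch.log(-torch.log(torch.rand(3)))        # g_k ~ Gumbel(0) via the inverse cdf
y = torch.argmax(torch.log(pi) + g)              # hard Gumbel-Max sample: an integer index

# argmax returns an integer tensor and is piecewise constant, so no gradient can
# flow back to pi through y; this is what the softmax relaxation below addresses.
print(y, y.requires_grad)                        # e.g. tensor(1) False
```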
Relaxation of Max with Softmax
For practical use, we would like to relax the \(\arg\max\) part. To this end, consider a one-hot vector \(V\) instead of the scalar-valued \(Y\), with elements defined as \(V_{k}:=\mathbb{I}[k=Y]\). Note that \(V\) represents a vertex of the \((m-1)\)-simplex. To approximate \(V\) by a nearby point \(\tilde{V}\) in the interior of the \((m-1)\)-simplex, the "soft"-max is adopted instead of the max:
$$\tilde{V}_{k}=\frac{\exp\left((\log\pi_{k}+g_{k})/\tau\right)}{\sum_{j=1}^{m}\exp\left((\log\pi_{j}+g_{j})/\tau\right)},\qquad g_{k}\sim\text{Gumbel}(0),$$
where the temperature \(\tau>0\) controls the continuous-discrete trade-off: as \(\tau\to0\), \(\tilde{V}\) approaches the one-hot vector \(V\), while larger \(\tau\) makes \(\tilde{V}\) smoother (closer to uniform). Now, since the softmax function is differentiable, we can apply gradient descent to learn the parameters \(\{\pi_{1}^{(i)},\dots,\pi_{m}^{(i)}\}^{D}\).
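A minimal PyTorch sketch of the relaxation (the helper name and the stand-in loss are my own); with the softmax in place of the hard \(\arg\max\), gradients now reach \(\pi\):

```python
import torch

def gumbel_softmax_sample(log_pi, tau):
    """Relaxed one-hot sample: softmax((log_pi + g) / tau) with g ~ Gumbel(0)."""
    u = torch.rand_like(log_pi).clamp(1e-9, 1 - 1e-9)   # avoid log(0) at the boundaries
    g = -torch.log(-torch.log(u))
    return torch.softmax((log_pi + g) / tau, dim=-1)

pi = torch.tensor([0.2, 0.5, 0.3], requires_grad=True)
v_tilde = gumbel_softmax_sample(torch.log(pi), tau=0.5)

loss = (v_tilde * torch.arange(3.0)).sum()       # stand-in loss f(V_tilde)
loss.backward()
print(v_tilde)                                   # a point in the interior of the 2-simplex
print(pi.grad)                                   # gradients now reach the parameters
```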