Softmax, mixtures and exponential families

Let’s take a look at a sentence from the variational classification paper.
Author: Paweł Czyż

Published: March 26, 2024

Recently, Carl kindly explained his variational classification paper to me. In particular, he recalled the following sentence: “the softmax layer can be interpreted as applying Bayes’ rule (…), assuming that the variables follow exponential family class-conditional distributions”.

I very much like this observation (as well as the paper), but I did not understand at all why it was true: isn't that too powerful? Let's try to rewrite it so that I understand it better. Consider a space of features \(\mathcal X\) and a space of labels \(\mathcal Y = \{1, 2, \dotsc, L\}\).

We want the conditional distributions \(P(X\mid Y=y)\) to have PDFs with respect to some nice reference measure \(\mu\) on \(\mathcal X\) and we will assume that these PDFs are positive everywhere. For example, (non-singular) multivariate normal and Student distributions have this property on \(\mathbb R^n\) (but truncated normal distributions generally do not).

Then, we can write \[ p_{X\mid Y}(x\mid y) = \exp(\log p_{X\mid Y}(x\mid y)) = f(x)\cdot \exp\!\big( \langle \eta_y, T(x) \rangle \big), \]

where \(T(x) = \big(\log p_{X\mid Y}(x\mid y) \big)_{y=1,\dotsc, L} \in \mathbb R^L\) is called the sufficient statistic; \(\eta_y \in \mathbb R^L\) is the \(y\)-th standard basis vector (i.e., the one-hot encoding of \(y\)), playing the role of the natural parameter; and \(f(x)=1\) is there just to make the formula look more familiar. It turns out that if the conditional distributions are fully supported, then they have to form¹ an exponential family!
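To see this rewriting in action, here is a minimal numerical sketch. The two Gaussian class-conditionals (and the use of SciPy) are my choice of a concrete example, not anything from the paper: we build \(T(x)\) from the log-densities, take one-hot natural parameters, and check that \(f(x)\exp(\langle \eta_y, T(x)\rangle)\) recovers the original PDFs.

```python
import numpy as np
from scipy.stats import norm

# Hypothetical example: L = 2 classes with Gaussian class-conditionals.
# (The specific distributions are made up for illustration.)
class_conditionals = [norm(loc=-1.0, scale=1.0), norm(loc=2.0, scale=0.5)]
L = len(class_conditionals)

def T(x):
    """Sufficient statistic: the vector of log-densities (log p(x | y))_y."""
    return np.array([d.logpdf(x) for d in class_conditionals])

def eta(y):
    """Natural parameter: the one-hot encoding of the label y."""
    e = np.zeros(L)
    e[y] = 1.0
    return e

x = 0.3
for y in range(L):
    # With f(x) = 1, the density should be exp(<eta_y, T(x)>).
    reconstructed = np.exp(eta(y) @ T(x))
    assert np.isclose(reconstructed, class_conditionals[y].pdf(x))
```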

So, in a way, whenever we have positive densities, we need to have an exponential family. We can turn this statement around using the quote from the beginning: whenever we have positive densities, we need to have a softmax! Namely, \[\begin{align*} p_{Y\mid X}(y\mid x) &= \frac{ p_{X\mid Y}(x\mid y)\, p_Y(y) }{ \sum_{y'} p_{X\mid Y}(x\mid y')\, p_Y(y') } \\ &= \mathrm{softmax}\big( \log p_{X\mid Y}(x\mid \bullet) + \log p_Y(\bullet) \big)_y. \end{align*} \]
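A quick sanity check of this identity, again with hypothetical Gaussian class-conditionals and a made-up prior (`scipy.special.softmax` handles the normalisation):

```python
import numpy as np
from scipy.special import softmax
from scipy.stats import norm

# Same hypothetical Gaussian class-conditionals as above, plus a prior p(y).
class_conditionals = [norm(loc=-1.0, scale=1.0), norm(loc=2.0, scale=0.5)]
prior = np.array([0.7, 0.3])

x = 0.3
log_lik = np.array([d.logpdf(x) for d in class_conditionals])

# Bayes' rule, computed directly ...
joint = np.exp(log_lik) * prior
posterior_bayes = joint / joint.sum()

# ... and as a softmax over log-likelihood + log-prior.
posterior_softmax = softmax(log_lik + np.log(prior))

assert np.allclose(posterior_bayes, posterior_softmax)
```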

Footnotes

  1. It may seem that we are missing the log-partition function, \(A_y\), but in this case it simply vanishes: \[\begin{align*} A_y &= \log \int_{\mathcal X} f(x) \exp(\langle \eta_y, T(x)\rangle) \, \mathrm{d}\mu(x) \\ &= \log \int_{\mathcal X} p_{X\mid Y}(x\mid y) \, \mathrm{d} \mu(x) \\ &= \log 1 = 0. \end{align*} \]
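A quick numerical check of this footnote, with a hypothetical Gaussian class-conditional on \(\mathbb R\) and the Lebesgue measure as the reference measure \(\mu\) (the distribution is my choice for illustration):

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# The density integrates to 1 over the whole space, so A_y = log 1 = 0.
p = norm(loc=-1.0, scale=1.0)
Z, _ = quad(p.pdf, -np.inf, np.inf)
A = np.log(Z)
assert np.isclose(A, 0.0)  # the log-partition function vanishes
```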