
Overview:

EM (expectation-maximisation) is a method for estimating the parameters of a latent variable model. For most non-trivial problems, latent variables are instrumental in simplifying calculations and in formalising our assumptions about the cause-effect relationships that led us to observe the given dataset.

Each EM iteration breaks down into two steps: a variational inference step, which optimises an objective with respect to a distribution over the latent variables, followed by a parameter optimisation step, which holds the distribution found in the previous step fixed.

Derivation:

Suppose you have observations collected in the set \(X\), latent variables collected in the set \(Z\), and model parameters \(\theta\). The set of possible values for \(Z\) is \(\mathcal{Z}\). Following maximum likelihood estimation, you would like to estimate the parameters using:

\[\begin{equation} \theta^*=\arg\max_{\theta\in \Theta}\quad \log p(X\mid \theta). \end{equation}\]

To make use of the latent variables, we introduce an arbitrary distribution \(q(Z)\) over them and write:

\[\begin{align} \log p(X\mid \theta)&=\log \int_{\mathcal{Z}}p(X, Z \mid \theta)dZ\notag\\ &=\log \int_{\mathcal{Z}}q(Z)\frac{p(X, Z\mid \theta)}{q(Z)}dZ\notag\\ &=\log \mathbb{E}_{Z\sim q}\left[\frac{p(X, Z\mid \theta)}{q(Z)}\right]\notag\\ &\ge \mathbb{E}_{Z\sim q}\left[ \log \frac{p(X, Z\mid \theta)}{q(Z)} \right] && \text{(by Jensen)}\notag\\ &=:F(q, \theta) \end{align}\\\]
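The bound can be checked numerically on a toy problem. The following is a minimal sketch, assuming a single binary latent variable and a made-up joint table `joint` (both are illustrative assumptions, not part of the derivation above); the integral over \(\mathcal{Z}\) becomes a two-term sum.

```python
import math

def elbo(q, joint):
    """F(q, theta) = sum_z q(z) * log(p(x, z | theta) / q(z)) for a discrete latent."""
    return sum(qz * math.log(pz / qz) for qz, pz in zip(q, joint) if qz > 0)

# Hypothetical values of p(x, z | theta) at the observed x, for z in {0, 1}.
joint = [0.3, 0.1]
log_evidence = math.log(sum(joint))  # log p(x | theta), summing out z

# Try several choices of q(z=0); F should never exceed the log evidence.
for a in (0.1, 0.5, 0.75, 0.9):
    q = [a, 1.0 - a]
    print(f"q(z=0)={a:.2f}  F={elbo(q, joint):.4f}  log p(x)={log_evidence:.4f}")
```

In this toy setup the gap closes only when \(q\) matches the normalised joint, i.e. `[0.75, 0.25]`.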

The last quantity, \(F(q, \theta)\), is known as the free energy or the Evidence Lower Bound (ELBO). The first name is borrowed from statistical physics. The second follows because \(p(X \mid \theta)\) is known as the evidence in type-II maximum likelihood estimation, and \(F\) is a lower bound on its logarithm.

E-step:

To derive the E-step we rewrite the free energy as:

\[\begin{align} F(q, \theta) &= \int_{\mathcal{Z}}q(Z)\log \frac{p(Z\mid X, \theta)}{q(Z)}dZ + \int_{\mathcal{Z}}q(Z)\log p(X\mid \theta)dZ\notag\\ &=\log p(X\mid \theta) - KL(q(Z) \mid\mid p(Z\mid X,\theta)). \end{align}\\\]

For the E-step we maximise the above expression with respect to \(q\). Since \(\log p(X\mid\theta)\) does not depend on \(q\), this amounts to minimising the KL term. Using the facts that \(KL(q\mid\mid p)\ge 0\) and \(KL(q\mid\mid p)=0\iff q(z)=p(z) \quad\forall z\), the E-step sets \(q^{new}(Z)=p(Z\mid X,\theta)\). The KL divergence is then zero, and \(F(q^{new}, \theta)=\log p(X\mid \theta)\). So the E-step makes the bound tight by doing an inference step (estimating the posterior of \(Z\)).
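To see the decomposition \(F = \log p(X\mid\theta) - KL\) at work, here is a small sketch on a discrete toy model. The joint table `joint` and the starting distribution `q_old` are made-up values for illustration only.

```python
import math

def elbo(q, joint):
    """F(q, theta) for a discrete latent variable."""
    return sum(qz * math.log(pz / qz) for qz, pz in zip(q, joint) if qz > 0)

def kl(q, p):
    """KL(q || p) for discrete distributions on the same support."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

# Hypothetical p(x, z | theta) at the observed x, for z in {0, 1}.
joint = [0.3, 0.1]
evidence = sum(joint)

# E-step: q_new(z) = p(z | x, theta), the normalised joint.
posterior = [pz / evidence for pz in joint]

q_old = [0.5, 0.5]
gap_old = math.log(evidence) - elbo(q_old, joint)      # equals KL(q_old || posterior)
gap_new = math.log(evidence) - elbo(posterior, joint)  # zero after the E-step

print(f"gap before E-step: {gap_old:.6f}  KL: {kl(q_old, posterior):.6f}")
print(f"gap after E-step:  {gap_new:.6f}")
```

The gap between the log evidence and the free energy should match the KL divergence before the E-step and vanish after it.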

M-step:

For the M-step we rewrite the free energy as:

\[\begin{equation} F(q^{new}, \theta) = \mathbb{E}_{Z\sim q^{new}}\left[ \log p(X, Z\mid \theta) \right] + H[q^{new}(Z)], \end{equation}\]

where \(H[q(Z)]=\mathbb{E}_{Z\sim q}[-\log q(Z)]\) is the entropy. The entropy term does not depend on \(\theta\), so it can be dropped from the optimisation. Holding \(q^{new}\) fixed, we maximise with respect to \(\theta\):

\[\begin{equation} \theta^{new}:=\arg\max_{\theta\in\Theta} \quad \mathbb{E}_{Z\sim q^{new}}\left[ \log p(X, Z\mid \theta) \right]. \end{equation}\]
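For some models this maximisation has a closed form. The sketch below assumes a two-component 1D Gaussian mixture with unit variances, so that \(\theta = (\pi, \mu_0, \mu_1)\); the model, the data `xs`, and the responsibilities `resp` (standing in for \(q^{new}\)) are illustrative assumptions rather than anything fixed by the derivation.

```python
import math

def log_normal(x, mu):
    """Log density of N(mu, 1) at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def expected_complete_ll(xs, resp, pi, mus):
    """E_{Z~q}[log p(X, Z | theta)] with q(z_n = 1) = resp[n]."""
    total = 0.0
    for x, r in zip(xs, resp):
        total += (1 - r) * (math.log(1 - pi) + log_normal(x, mus[0]))
        total += r * (math.log(pi) + log_normal(x, mus[1]))
    return total

def m_step(xs, resp):
    """Closed-form maximiser of expected_complete_ll for this toy mixture."""
    n1 = sum(resp)               # expected number of points in component 1
    n0 = len(xs) - n1            # expected number of points in component 0
    pi = n1 / len(xs)
    mu0 = sum((1 - r) * x for x, r in zip(xs, resp)) / n0
    mu1 = sum(r * x for x, r in zip(xs, resp)) / n1
    return pi, (mu0, mu1)

xs = [-2.1, -1.9, -0.3, 1.8, 2.2]
resp = [0.05, 0.1, 0.4, 0.9, 0.95]  # hypothetical q_new(z_n = 1)

pi_new, mus_new = m_step(xs, resp)
best = expected_complete_ll(xs, resp, pi_new, mus_new)
print(f"pi={pi_new:.3f}  mu0={mus_new[0]:.3f}  mu1={mus_new[1]:.3f}  Q={best:.4f}")
```

The mixing weight and means come out as responsibility-weighted averages, and perturbing any of them should lower the expected complete-data log likelihood.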

Single EM iteration:

To summarise a single iteration of EM:

\[\begin{equation} F(q^{k}, \theta^{k})\le\log p(X\mid \theta^{k}) = F(q^{k+1}, \theta^{k})\le F(q^{k+1}, \theta^{k+1})\le \log p(X\mid \theta^{k+1}). \end{equation}\]

The first inequality is due to Jensen’s inequality and log being concave. The following equality is due to the E-step setting \(q^{k+1} (Z) = p(Z \mid X, \theta^{k})\), which makes the bound tight. The next inequality is due to the M-step maximising \(F\) with respect to \(\theta\) while holding \(q\) fixed, and the final inequality is again due to Jensen’s inequality.

Concluding remarks:

In summary, an iteration of EM never decreases the log likelihood. We iterate E and M steps until convergence.
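The monotonic behaviour can be observed by running the two steps in a loop. This is a minimal sketch, assuming a two-component 1D Gaussian mixture with unit variances and synthetic data; the model, the function names, the initial values, and the stopping tolerance `tol` are all illustrative choices.

```python
import math
import random

def log_normal(x, mu):
    """Log density of N(mu, 1) at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def log_likelihood(xs, pi, mus):
    """log p(X | theta), with the latent component summed out per point."""
    return sum(
        math.log((1 - pi) * math.exp(log_normal(x, mus[0]))
                 + pi * math.exp(log_normal(x, mus[1])))
        for x in xs
    )

def e_step(xs, pi, mus):
    """Posterior probability that each point belongs to component 1."""
    resp = []
    for x in xs:
        w0 = (1 - pi) * math.exp(log_normal(x, mus[0]))
        w1 = pi * math.exp(log_normal(x, mus[1]))
        resp.append(w1 / (w0 + w1))
    return resp

def m_step(xs, resp):
    """Closed-form update of (pi, mu0, mu1) given the responsibilities."""
    n1 = sum(resp)
    n0 = len(xs) - n1
    mu0 = sum((1 - r) * x for x, r in zip(xs, resp)) / n0
    mu1 = sum(r * x for x, r in zip(xs, resp)) / n1
    return n1 / len(xs), (mu0, mu1)

def run_em(xs, pi, mus, max_iters=200, tol=1e-8):
    """Alternate E and M steps until the log likelihood stops improving."""
    history = [log_likelihood(xs, pi, mus)]
    for _ in range(max_iters):
        resp = e_step(xs, pi, mus)
        pi, mus = m_step(xs, resp)
        history.append(log_likelihood(xs, pi, mus))
        if history[-1] - history[-2] < tol:
            break
    return pi, mus, history

# Synthetic data: 150 points around -2 and 50 points around 3.
rng = random.Random(0)
xs = [rng.gauss(-2.0, 1.0) for _ in range(150)] + [rng.gauss(3.0, 1.0) for _ in range(50)]

pi, mus, history = run_em(xs, pi=0.5, mus=(-1.0, 1.0))
print(f"iterations={len(history) - 1}  pi={pi:.3f}  mu0={mus[0]:.3f}  mu1={mus[1]:.3f}")
```

The list `history` records the log likelihood after every iteration, so it can be inspected to confirm that no step decreases it.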

It can be shown that instead of fully maximising in the M-step, we can just take a step that increases the free energy (e.g. a single, suitably sized gradient step) and still retain the properties of the above formulation of EM. This “partial M-step” approach is known as generalised EM (GEM).

Note: If using GEM, take care when implementing the partial M-step that the step is not so large that it overshoots and decreases the free energy, moving the parameters further away from the local optimum. Tuning the step size for the partial M-step might be helpful (e.g. for a gradient-based partial M-step).
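One way to guard against overly large steps is to shrink the step until the M-step objective stops decreasing. The sketch below assumes an equal-weight, unit-variance mixture of two 1D Gaussians with \(\theta = (\mu_0, \mu_1)\), and uses simple step halving; the model, the function names, and the halving rule are illustrative assumptions, not a prescribed GEM implementation.

```python
import math
import random

def log_normal(x, mu):
    """Log density of N(mu, 1) at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def log_likelihood(xs, mus):
    """log p(X | theta) for an equal-weight two-component mixture."""
    return sum(
        math.log(0.5 * math.exp(log_normal(x, mus[0]))
                 + 0.5 * math.exp(log_normal(x, mus[1])))
        for x in xs
    )

def e_step(xs, mus):
    """Posterior probability that each point belongs to component 1."""
    resp = []
    for x in xs:
        w0 = math.exp(log_normal(x, mus[0]))
        w1 = math.exp(log_normal(x, mus[1]))
        resp.append(w1 / (w0 + w1))
    return resp

def expected_complete_ll(xs, resp, mus):
    """M-step objective, dropping the constant log(0.5) terms."""
    return sum((1 - r) * log_normal(x, mus[0]) + r * log_normal(x, mus[1])
               for x, r in zip(xs, resp))

def partial_m_step(xs, resp, mus, step=1.0):
    """One gradient-ascent step, halving the step size until the objective does not decrease."""
    grad = (
        sum((1 - r) * (x - mus[0]) for x, r in zip(xs, resp)),
        sum(r * (x - mus[1]) for x, r in zip(xs, resp)),
    )
    q_old = expected_complete_ll(xs, resp, mus)
    while step > 1e-12:
        candidate = (mus[0] + step * grad[0], mus[1] + step * grad[1])
        if expected_complete_ll(xs, resp, candidate) >= q_old:
            return candidate
        step *= 0.5  # step was too large: it overshot and lowered the objective
    return mus       # no acceptable step found; keep the current parameters

rng = random.Random(1)
xs = [rng.gauss(-2.0, 1.0) for _ in range(100)] + [rng.gauss(3.0, 1.0) for _ in range(100)]

mus = (-1.0, 1.0)
history = [log_likelihood(xs, mus)]
for _ in range(200):
    resp = e_step(xs, mus)
    mus = partial_m_step(xs, resp, mus)
    history.append(log_likelihood(xs, mus))

print(f"mu0={mus[0]:.3f}  mu1={mus[1]:.3f}  log-likelihood={history[-1]:.3f}")
```

Because each accepted step does not decrease the M-step objective, the log likelihood recorded in `history` should still be non-decreasing, even though no iteration solves the M-step exactly.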