1 The EM algorithm

In this set of notes, we discuss the EM (Expectation-Maximization) algorithm, a common algorithm used in statistical estimation to try to find the MLE.

It is often used for likelihoods that are not themselves exponential families but are derived from exponential families. A common mechanism by which such likelihoods arise is missing data, i.e. we only observe some of the sufficient statistics of the family.

1.1 Mixture model

A canonical application of the EM algorithm is its use in fitting a mixture model where we assume we observe an IID sample $(X_i)_{1 \leq i \leq n}$ from
$$Y \sim \text{Multinomial}(1, \pi), \qquad X \mid Y = l \sim P_l, \qquad \pi \in \mathbb{R}^L,$$
with the simplest example of $P_l$ being the univariate normal model
$$P_l = N(\mu_l, \sigma_l^2),$$

keeping in mind that the parameters on the right are the mean space parameters, not the natural parameters.

1.1.1 Exercise

1. Show that the joint distribution of $(X, Y)$ is an exponential family. What is its reference measure, its sufficient statistics? Write out the log-likelihood based on observing an IID sample $(X_i, Y_i)_{1 \leq i \leq n}$ for this model. Call this $\ell_c(\theta; X, Y)$ the complete likelihood.

2. What is the marginal density of X?

3. Write out the log-likelihood $\ell(\theta; X)$ based on observing an IID sample $(X_i)_{1 \leq i \leq n}$ from this model. What are its parameters?

In the mixture model, we only observe X, though the marginal distribution of X is the same as if we had generated pairs (X, Y ) and marginalized over Y . In this problem, Y is missing data which we might call M , and X is observed data which we might call O. Formally, then, we partition our sufficient statistic into two sets: those observed, and those missing.

1.2 The EM algorithm

The EM algorithm usually has two steps, both of which are based on the following function:
$$Q(\theta; \tilde{\theta}) = E_{\tilde{\theta}}\left[\ell_c(\theta; O, M) \mid O\right].$$

The basis of the EM algorithm is the following result:
$$Q(\theta; \tilde{\theta}) \geq Q(\tilde{\theta}; \tilde{\theta}) \implies \ell(\theta; O) \geq \ell(\tilde{\theta}; O).$$


Therefore, any sequence $(\theta^{(k)})_{k \geq 1}$ satisfying
$$Q(\theta^{(k+1)}; \theta^{(k)}) \geq Q(\theta^{(k)}; \theta^{(k)})$$
has $\ell(\theta^{(k)}; O)$ non-decreasing. An algorithm that produces such a sequence is called a GEM algorithm (generalized EM algorithm).
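In particular, the exact EM update, which maximizes $Q(\cdot; \theta^{(k)})$ over $\theta$, is a special case of a GEM algorithm, since
$$Q(\theta^{(k+1)}; \theta^{(k)}) = \max_\theta Q(\theta; \theta^{(k)}) \geq Q(\theta^{(k)}; \theta^{(k)}).$$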

The proof of this is fairly straightforward after some initial sleight of hand. After this sleight of hand, we see that the main ingredient in the proof is the deviance of the conditional distribution of $M \mid O$. In the general case, this deviance is not expressed in terms of natural parameters, but the argument is the same.

Here is the proof: writing the joint distribution of $(O, M)$ (assuming it has a density with respect to $P_0$) as
$$\frac{dP_\theta}{dP_0} = f_{\theta,(O,M)}(o, m) = f_{\theta,O}(o) \cdot f_{\theta,M|O}(m|o)$$
where the $f$'s are densities with respect to $P_0$. Or,
$$f_{\theta,O}(o) = \frac{f_{\theta,(O,M)}(o, m)}{f_{\theta,M|O}(m|o)}.$$

Although the RHS seems to depend on m, the above equality shows that it is actually measurable with respect to o.

We see that
$$\begin{aligned}
\ell(\theta; O) &= \sum_{i=1}^n \log f_\theta(O_i) \\
&= \sum_{i=1}^n \left[\log f_\theta(O_i, M_i) - \log f_\theta(M_i \mid O_i)\right]
\end{aligned}$$

where we know that $f_\theta(m \mid o)$ is an exponential family for $O$ fixed. The right hand side is measurable with respect to $O$, so its conditional expectation with respect to $O$ leaves it unchanged. Therefore, for any $\tilde{\theta}$ we have the equality

$$\begin{aligned}
\ell(\theta; O) &= \sum_{i=1}^n \log f_\theta(O_i) \\
&= \sum_{i=1}^n \left[\log f_\theta(O_i, M_i) - \log f_\theta(M_i \mid O_i)\right] \\
&= E_{\tilde{\theta}}\left[\sum_{i=1}^n \log f_\theta(O_i, M_i) \,\Big|\, O\right] - \sum_{i=1}^n E_{\tilde{\theta}}\left[\log f_\theta(M_i \mid O_i) \,\Big|\, O_i\right] \\
&= E_{\tilde{\theta}}\left[\ell_c(\theta; O, M) \mid O\right] - \sum_{i=1}^n E_{\tilde{\theta}}\left[\log f_\theta(M_i \mid O_i) \mid O_i\right] \\
&= Q(\theta; \tilde{\theta}) - \sum_{i=1}^n E_{\tilde{\theta}}\left[\log f_\theta(M_i \mid O_i) \mid O_i\right].
\end{aligned}$$

Now,

$$\ell(\theta; O) - \ell(\tilde{\theta}; O) = Q(\theta; \tilde{\theta}) - Q(\tilde{\theta}; \tilde{\theta}) + \sum_{i=1}^n \left( E_{\tilde{\theta}}\left[\log f_{\tilde{\theta}}(M_i \mid O_i) \mid O_i\right] - E_{\tilde{\theta}}\left[\log f_\theta(M_i \mid O_i) \mid O_i\right] \right).$$

The term
$$\sum_{i=1}^n \left( E_{\tilde{\theta}}\left[\log f_{\tilde{\theta}}(M_i \mid O_i) \mid O\right] - E_{\tilde{\theta}}\left[\log f_\theta(M_i \mid O_i) \mid O\right] \right)$$
is essentially half the deviance of the exponential family of conditional distributions for $M \mid O$ with sufficient statistic $M$.

To see this, recall our general form of the conditional density of $T_1 \mid T_2 = s_2$ for an $\mathbb{R}^p$-valued sufficient statistic partitioned as $T_1 \in \mathbb{R}^k$, $T_2 \in \mathbb{R}^{p-k}$:
$$\begin{aligned}
f_{T_1 \mid T_2 = s_2}(t_1) &= \frac{f_{T_1,T_2}(t_1, s_2)}{\int_{\mathbb{R}^k} f_{T_1,T_2}(s_1, s_2) \, ds_1} \\
&= \frac{e^{\theta_1^T t_1 + \theta_2^T s_2} \, \tilde{m}_0(t_1, s_2)}{\int_{\mathbb{R}^k} e^{\theta_1^T s_1 + \theta_2^T s_2} \, \tilde{m}_0(s_1, s_2) \, ds_1} \\
&= \frac{e^{\theta_1^T t_1} \, \tilde{m}_0(t_1, s_2)}{\int_{\mathbb{R}^k} e^{\theta_1^T s_1} \, \tilde{m}_0(s_1, s_2) \, ds_1}.
\end{aligned}$$

Therefore, with $C$ a function independent of $\theta$,
$$\begin{aligned}
\log f_\theta(M_i \mid O_i) &= \theta_M^T M_i - \log \int_{\mathbb{R}^k} e^{\theta_M^T s} \, \tilde{m}_0(s, O_i) \, ds + C(M_i, O_i) \\
&= \theta_M^T M_i - \tilde{\Lambda}(\theta_M, O_i) + C(M_i, O_i)
\end{aligned}$$
where $\tilde{\Lambda}(\theta_M, O_i)$ is the appropriate CGF for this conditional distribution.


We see, then, that
$$\log f_{\tilde{\theta}}(M_i \mid O_i) - \log f_\theta(M_i \mid O_i) = \tilde{\Lambda}(\theta_M, O_i) - \tilde{\Lambda}(\tilde{\theta}_M, O_i) - (\theta_M - \tilde{\theta}_M)^T M_i.$$
Taking conditional expectation with respect to $O$ at $\tilde{\theta}$ yields
$$\sum_{i=1}^n E_{\tilde{\theta}}\left[\log f_{\tilde{\theta}}(M_i \mid O_i) - \log f_\theta(M_i \mid O_i) \mid O\right] = \frac{1}{2} D(\tilde{\theta}; \theta \mid O) \geq 0.$$
Combining this with the identity for $\ell(\theta; O) - \ell(\tilde{\theta}; O)$ above, we see that $Q(\theta; \tilde{\theta}) \geq Q(\tilde{\theta}; \tilde{\theta})$ implies $\ell(\theta; O) \geq \ell(\tilde{\theta}; O)$, which proves the result.

1.3 The two basic steps

The algorithm is often described as having two steps: the E step and the M step. Formally, the E step can be described as evaluating $Q(\theta; \tilde{\theta})$ with $\tilde{\theta}$ fixed. That is, fix $\tilde{\theta}$ and compute
$$q_{\tilde{\theta}}(\theta) = E_{\tilde{\theta}}\left[\ell_c(\theta; O, M) \mid O\right]$$
as a function of $\theta$. The M step is the maximization step and amounts to finding
$$\hat{\theta}(\tilde{\theta}) \in \text{argmax}_\theta \, Q(\theta; \tilde{\theta}) = \text{argmax}_\theta \, q_{\tilde{\theta}}(\theta).$$
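To make the iteration concrete, here is a minimal Python sketch of the generic E step / M step alternation. The functions e_step and m_step are hypothetical placeholders (not defined in these notes), standing in for the model-specific computation of the imputed sufficient statistics and of the maximizer of $Q$.

def em(theta0, data, e_step, m_step, niter=50):
    """
    Generic EM iteration (sketch).

    e_step(theta, data) should return the conditional expectation of the
    complete-data sufficient statistics given the observed data;
    m_step(expected_stats, data) should return the maximizer of the
    resulting Q function.
    """
    theta = theta0
    for _ in range(niter):
        expected_stats = e_step(theta, data)   # E step
        theta = m_step(expected_stats, data)   # M step
    return theta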

1.4 EM algorithm for exponential families

The EM algorithm for exponential families takes a particularly nice form when the MLE map is nice in the complete data problem. Expressed sequentially, it is given by the recursion
$$\hat{\theta}^{(k+1)} = \text{argmax}_\theta \left[ \theta^T E_{\theta^{(k)}}\left( t(M, O) \mid O \right) - \Lambda(\theta) \right].$$

In other words, we need to form the conditional expectation of all the sufficient statistics given the sufficient statistics we did observe. Following this, we just return the MLE as if we had observed those sufficient statistics.

Another way to phrase this is

$$\nabla \Lambda\left(\hat{\theta}^{(k+1)}\right) = E_{\theta^{(k)}}\left( t(M, O) \mid O \right),$$
i.e. the updated mean parameter is set equal to the imputed sufficient statistics.

1.5 Mixture model example

In the mixture model, if we write $Y_i = (Y_{i1}, \dots, Y_{iL})$ for the multinomial indicator vector (so $Y_{il} = 1$ when observation $i$ comes from component $l$), the sufficient statistics can be taken to be
$$t(X, Y) = \left( \sum_{i=1}^n Y_{ij}, \; \sum_{i=1}^n Y_{ij} X_i, \; \sum_{i=1}^n Y_{ij} X_i^2 \right)_{1 \leq j \leq L},$$
where only
$$\sum_{j=1}^L Y_{ij} X_i = X_i, \qquad 1 \leq i \leq n,$$
is observed.


1.5.1 Exercise

Use Bayes rule to show that, in our univariate normal mixture model,
$$P(Y = l \mid X = x) = \frac{\pi_l \, \phi(x, \mu_l, \sigma_l^2)}{\sum_{j=1}^L \pi_j \, \phi(x, \mu_j, \sigma_j^2)}$$
where $\phi(x, \mu, \sigma^2)$ is the density of $N(\mu, \sigma^2)$.

If we set $\hat{\gamma}_l(x, \tilde{\theta}) = P_{\tilde{\theta}}(Y = l \mid X = x)$, the above exercise shows that

$$\begin{aligned}
E_{\tilde{\theta}}\left[\sum_{i=1}^n Y_{il} X_i \,\Big|\, X\right] &= \sum_{i=1}^n \hat{\gamma}_l(X_i, \tilde{\theta}) X_i \\
E_{\tilde{\theta}}\left[\sum_{i=1}^n Y_{il} X_i^2 \,\Big|\, X\right] &= \sum_{i=1}^n \hat{\gamma}_l(X_i, \tilde{\theta}) X_i^2 \\
E_{\tilde{\theta}}\left[\sum_{i=1}^n Y_{il} \,\Big|\, X\right] &= \sum_{i=1}^n \hat{\gamma}_l(X_i, \tilde{\theta}).
\end{aligned}$$

The usual MLE map (for the mean parameters) in this model can be expressed as
$$\begin{aligned}
\hat{\pi}_l &= \sum_{i=1}^n Y_{il} / n \\
\hat{\mu}_l &= \frac{\sum_{i=1}^n Y_{il} X_i}{\sum_{i=1}^n Y_{il}} \\
\hat{\sigma}_l^2 &= \frac{\sum_{i=1}^n Y_{il} (X_i - \hat{\mu}_l)^2}{\sum_{i=1}^n Y_{il}} = \frac{\sum_{i=1}^n Y_{il} X_i^2}{\sum_{i=1}^n Y_{il}} - \left( \frac{\sum_{i=1}^n Y_{il} X_i}{\sum_{i=1}^n Y_{il}} \right)^2.
\end{aligned}$$
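This MLE map is easy to code directly; below is a minimal NumPy sketch (the function name mle_map and its weight-matrix interface are ours, not from the notes). Passing the 0/1 indicator matrix gives the complete-data MLE, while passing the responsibilities in its place gives exactly the M step update used next.

import numpy as np

def mle_map(Y, X):
    """
    MLE map for the mean parameters of the univariate normal mixture.
    Y is an (n, L) matrix of membership weights (indicators or
    responsibilities) and X is the (n,) data vector.
    """
    n = X.shape[0]
    counts = Y.sum(0)                                  # sum_i Y_il, one entry per component
    pi_hat = counts / n
    mu_hat = (Y * X[:, None]).sum(0) / counts
    sigmasq_hat = (Y * X[:, None]**2).sum(0) / counts - mu_hat**2
    return pi_hat, mu_hat, sigmasq_hat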

This leads to the following algorithm: given an initial set of parameters $\theta^{(0)}$, we repeat the following updates for $k \geq 0$:

• Form the responsibilities $\hat{\gamma}_l(X_i; \theta^{(k)})$, $1 \leq l \leq L$, $1 \leq i \leq n$.

• Compute
$$\begin{aligned}
\hat{\pi}_l^{(k+1)} &= \sum_{i=1}^n \hat{\gamma}_l(X_i; \theta^{(k)}) / n \\
\hat{\mu}_l^{(k+1)} &= \frac{\sum_{i=1}^n \hat{\gamma}_l(X_i; \theta^{(k)}) X_i}{\sum_{i=1}^n \hat{\gamma}_l(X_i; \theta^{(k)})} \\
\hat{\sigma}_l^{2,(k+1)} &= \frac{\sum_{i=1}^n \hat{\gamma}_l(X_i; \theta^{(k)}) X_i^2}{\sum_{i=1}^n \hat{\gamma}_l(X_i; \theta^{(k)})} - \left( \hat{\mu}_l^{(k+1)} \right)^2
\end{aligned}$$

• Repeat.


Let's test out our algorithm on some data from the mixture model.

# imports used below
import numpy as np
import matplotlib.pyplot as plt

mu1, sigma1 = 2, 1
mu2, sigma2 = -1, 0.8
X1 = np.random.standard_normal(200)*sigma1 + mu1
X2 = np.random.standard_normal(600)*sigma2 + mu2
X = np.hstack([X1,X2])

%R -i X plot(density(X))

def phi(x, mu, sigma):
    """
    Normal density
    """
    return np.exp(-(x-mu)**2 / (2 * sigma**2)) / np.sqrt(2 * np.pi * sigma**2)

def responsibilities(X, params):
    """
    Compute the responsibilities, as well as the likelihood at the same time.
    """
    mu1, mu2, sigma1, sigma2, pi1, pi2 = params
    gamma1 = phi(X, mu1, sigma1) * pi1
    gamma2 = phi(X, mu2, sigma2) * pi2
    denom = gamma1 + gamma2
    gamma1 /= denom
    gamma2 /= denom
    return np.array([gamma1, gamma2]).T, np.log(denom).sum()

mu1, mu2, sigma1, sigma2, pi1, pi2 = 0, 1, 1, 4, 0.5, 0.5
gamma, likelihood = responsibilities(X, (mu1, mu2, sigma1, sigma2, pi1, pi2))

Here is our recursive estimation procedure, which is fairly straightforward here.

niter = 20
n = X.shape[0]
values = []
for _ in range(niter):
    # E step: responsibilities and log-likelihood at the current parameters
    gamma, likelihood = responsibilities(X, (mu1, mu2, sigma1, sigma2, pi1, pi2))
    # M step: update mixture weights, means and variances
    pi1, pi2 = gamma.sum(0) / n
    mu1 = (gamma[:,0] * X).sum() / (pi1*n)
    mu2 = (gamma[:,1] * X).sum() / (pi2*n)
    sigma1_sq = (gamma[:,0] * X**2).sum() / (n*pi1) - mu1**2
    sigma2_sq = (gamma[:,1] * X**2).sum() / (n*pi2) - mu2**2
    sigma1 = np.sqrt(sigma1_sq)
    sigma2 = np.sqrt(sigma2_sq)
    values.append(likelihood)

We can track the value of the likelihood and, since we have an EM algorithm, the likelihood should be monotone with iterations.

plt.plot(values)
plt.gca().set_ylabel(r'$\ell^{(k)}$')
plt.gca().set_xlabel(r'Iteration $k$')
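As a quick sanity check (not in the original notes), we can also verify the monotonicity numerically, allowing a small floating point tolerance:

# the recorded log-likelihood values should be non-decreasing
assert np.all(np.diff(values) >= -1e-8), "log-likelihood decreased between iterations"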


Let's plot our density estimate to see how well the mixture model was fit.

%%R -i pi1,pi2,sigma1,sigma2,mu1,mu2
X = sort(X)
plot(X, pi1*dnorm(X,mu1,sigma1)+pi2*dnorm(X,mu2,sigma2), col='red', lwd=2, type='l',
     ylab='Density')
lines(density(X))

