EMアルゴリズムを今更理解する

趣旨

EMアルゴリズムを理解する。

なぜEMアルゴリズム?

なんか名前がかっこいいから機械学習を学ぶ人間として当然だから。

他にも解説記事いっぱいあるけど?

表記揺れ、定義の揺れが酷すぎて比較しづらい。例えば後述のEステップについて揺れがある。


では解説する。

表記法

  • \boldsymbol{x}: 観測変数(即ち入力)
  • \boldsymbol{z}: 潜在変数(観測されない変数 欠損データと呼ばれたりするが、実在性は特に重要ではない)
  • \boldsymbol{\theta}: モデルパラメタ

問題設定

最尤推定がしたい。つまり、事前確率 p(\boldsymbol{x} | \boldsymbol{\theta})を最大にするような \boldsymbol{\theta}を求めたい。


\boldsymbol{\theta}_{\text{ML}} \leftarrow \max_{\boldsymbol{\theta}} \log p(\boldsymbol{x} | \boldsymbol{\theta})

しかし、直接偏微分して解くことが難しく、代わりに p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})は陽に得られている状況を考える。

解法

 \boldsymbol{\theta}に適当な初期値 \boldsymbol{\theta}^{(0)}を与えた後、以下のEステップ、Mステップを収束するまで繰り返す。

Eステップ

まず以下のような事後確率 q(\boldsymbol{z})を計算する。


q(\boldsymbol{z}) = p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta}^{(i)})

次に先ほど計算した q(\boldsymbol{z})を用いて、以下のような期待値を Q^{(i)}(\boldsymbol{\theta})とする。


\begin{align}
Q^{(i)}(\boldsymbol{\theta})
 & = \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z})} \left\lbrack \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) \right\rbrack \\\
 & = \int q(\boldsymbol{z}) \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d \boldsymbol{z} \\\
\end{align}

期待値をとってるからE(xpectation)ステップ。

Mステップ

先ほど計算した Q^{(i)}(\boldsymbol{\theta}) \boldsymbol{\theta}について最大化し解を \boldsymbol{\theta}^{(i+1)}とする。


 \boldsymbol{\theta}^{(i+1)} \leftarrow \text{argmax}_{\boldsymbol{\theta}} Q^{(i)}(\boldsymbol{\theta})

最大化しているからM(aximization)ステップ。

なんでうまくいくの?

細かい計算は飛ばして、ざっくり概要を述べる。めんどくさいし。

Eステップ

1. 潜在変数による周辺化

目的関数である \log p(\boldsymbol{x} | \boldsymbol{\theta})を潜在変数\boldsymbol{z}についての周辺化によって表現する。


\log p(\boldsymbol{x} | \boldsymbol{\theta}) = \log \int p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d \boldsymbol{z}

2.  q(\boldsymbol{z})を導入

任意の確率分布 q(\boldsymbol{z})を導入し、以下のように式変形する。


\log \int p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d \boldsymbol{z} = \log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z}

3. Jensenの不等式を導入

3-1. Jensenの不等式って?

関数 fが(下に)凸であるとき、


\begin{align}
u(x) \geq 0, &\quad \int u(x) dx = 1 \\\
&\Rightarrow \int u(x) f(v(x)) dx \geq f(\int u(x) v(x) dx)
\end{align}

が成立し、等号成立は以下の2種類のいずれかを満足すると成立する。

  1.  v(x) = m ( xの値に関わらず一定)
  2.  f(x)が原点を通る直線である(まず満足しない)

これをJensen(イェンゼン、イェンセン)の不等式と言う。なんでこれが成り立つのかは他のサイトに譲る。ともかくこれを適用することを試みる。

3-2. さっきの式に適用する

具体的には以下の右辺。


\log \int p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d \boldsymbol{z} = \log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z}

関数 f \logであって欲しい。残念ながら \logは凸は凸でも上に凸だ。しかし逆に考えれば -\logとすれば下に凸になってくれる。よって、


\begin{align}
f(\cdot) &\rightarrow -\log (\cdot) \\\
u(x) &\rightarrow q(\boldsymbol{\theta}) \\\
v(x) &\rightarrow \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \\\
\end{align}

のような対応関係を結べば、


f(\int u(x) v(x) dx) \rightarrow -\log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z}

と分かり、不等式は


\log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z} \geq \int q(\boldsymbol{z}) \log \left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z}

と書ける(不等号の向きに注意せよ)。

3-3. 等号成立条件

では等号成立はどの時だろうか。Jensenによれば、それは v(x) = mの時である( \logが直線な訳ないので)。即ち、


\frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} = m

\therefore p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) = m q(\boldsymbol{z})

で、 mっていくらなの? それは \boldsymbol{z}について周辺化してやれば分かる。


\begin{align}
\int p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d\boldsymbol{z} &= \int m q(\boldsymbol{z}) d\boldsymbol{z} \\\
p(\boldsymbol{x} | \boldsymbol{\theta}) &= m \\\
\end{align}

と言うわけで、 m=p(\boldsymbol{x} | \boldsymbol{\theta})であり、等号成立条件は


q(\boldsymbol{z}) = \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{m}  = \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{p(\boldsymbol{x} | \boldsymbol{\theta})} = p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta})

と書き表せるのだった。

3-4. 等号が成立したら何が嬉しいの?

先ほどの不等式


\log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z} \geq \int q(\boldsymbol{z}) \log \left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z}

において左辺は本来は q(\boldsymbol{z})に依存しない値であるから、 q(\boldsymbol{z})をいじって等号を成立させることは右辺(下限)の最大化を意味する。よって、目的関数の値は変わってないが、それの下限の値が最大化され、続くMステップにおいて目的関数の値を増加させるのに貢献するのである。

3-5. Eステップとやってること違わない?

右辺について \logに着目し式展開を行うと、


\begin{align}
\int q(\boldsymbol{z}) \log &\left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z} \\\
 &= \int q(\boldsymbol{z}) \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) d \boldsymbol{z} + \text{Const.} \\\
 &= \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z})} \left\lbrack \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) \right\rbrack + \text{Const.} \\\
\end{align}

と書ける。よって、

  1.  q(\boldsymbol{z}) = p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta})とした上で \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z})} \left\lbrack \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) \right\rbrackを求めることは
  2.  \int q(\boldsymbol{z}) \log \left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z}の最大値を求めること、ひいては
  3.  \log p(\boldsymbol{x} | \boldsymbol{\theta})の(あくまで)下限の q(\boldsymbol{z})における最大値を求めることに等しい。

3-6. 厳密には

実際のEステップでは、 \boldsymbol{\theta}の値として \boldsymbol{\theta}^{(i)}が得られている。よって、先ほどまでの議論は実は \log p(\boldsymbol{x} | \boldsymbol{\theta}^{(i)})の下限を最大化していたのだった。ではなぜそう書かないのか。おそらくだが、続くMステップで \boldsymbol{\theta}の値をいじるため、定数 \boldsymbol{\theta}^{(i)}ではなく変数 \boldsymbol{\theta}として扱うことを強調したかったのだろう。

Mステップ

先のEステップにおいて、 \log p(\boldsymbol{x} | \boldsymbol{\theta})の下限を q(\boldsymbol{z})について最大化した。Mステップでは、その下限を今度は \boldsymbol{\theta}について最大化している。

Jensenの不等式における右辺はEステップによって


\begin{align}
\int q(\boldsymbol{z}) \log &\left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z} \\\
 &= \mathbb{E}_{\boldsymbol{z} \sim q(\boldsymbol{z})} \left\lbrack \log p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}) \right\rbrack + \text{Const.} \\\
 &= Q^{(i)}(\boldsymbol{\theta}) + \text{Const.} \\\
\end{align}

のように表記できる。この式の最右辺を \boldsymbol{\theta}について最大化することは、 \log p(\boldsymbol{x} | \boldsymbol{\theta})の(あくまで)下限を最大化することと等価である。しかしここで問題が発生する。

等号成立の崩壊

Eステップで導入したJensenの不等式において、等号成立条件を


q(\boldsymbol{z}) = p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta})

と書いた。そして、実際は定数 \boldsymbol{\theta}^{(i)}を用いて、


q(\boldsymbol{z}) = p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta}^{(i)})

のように計算した。しかし、 \boldsymbol{\theta}について最大化、つまり \boldsymbol{\theta}の値をいじって \boldsymbol{\theta}^{(i+1)}としたとき、


q(\boldsymbol{z}) \neq p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta}^{(i+1)})

となってしまう。要するに、旧来の( \boldsymbol{\theta}^{(i)}を用いた) q(\boldsymbol{z})のままでは、不等式


\log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}^{(i+1)})}{q(\boldsymbol{z})} d \boldsymbol{z} \geq \int q(\boldsymbol{z}) \log \left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta}^{(i+1)})}{q(\boldsymbol{z})} \right) d \boldsymbol{z}

の等号成立条件を満たさなくなってしまう。これは何を意味するのか?

等号成立の崩壊は進歩を意味する

Mステップでやっていることは、等号が成立していた不等式


\log \int q(\boldsymbol{z}) \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} d \boldsymbol{z} \geq \int q(\boldsymbol{z}) \log \left( \frac{p(\boldsymbol{x}, \boldsymbol{z} | \boldsymbol{\theta})}{q(\boldsymbol{z})} \right) d \boldsymbol{z}

の右辺を最大化することで等号を不成立にすることである。最大化しているのだから、右辺はもちろん増加する。だが、等号が不成立なのだから、左辺はそれ以上に増加するのである。左辺というのはすなわち目的関数なので、目的関数の値が増加する結果となるのである。

Eステップに戻る

等号が不成立になっちゃったのなら、また q(\boldsymbol{z})をいじって成立させればいいじゃない、そしてまた最大化して不成立にさせればいいじゃない。それを繰り返せばいつかは目的関数の値は上限に到達し、 \boldsymbol{\theta}^{(i)}は最尤解 \boldsymbol{\theta}_{\text{ML}}に限りなく近づくだろう。

結局

EMアルゴリズムは、 \log p(\boldsymbol{x} | \boldsymbol{\theta})の下限を

  1.  q(\boldsymbol{z})について最大化(目的関数 =下限)
  2.  \boldsymbol{\theta}について最大化(目的関数 >下限になり結果的に目的関数の値が増加)

し続けるアルゴリズムなのである。終わり。


具体的にはどこで使われてるの?

GMM(混合ガウス分布)があまりにも有名だが、IRT(項目反応理論)でも用いられていたりする。予想しないタイミングでひょっこり顔を出すので注意。

GMMだとどういう更新則になるの?

知らん。その辺のサイトに転がっているので確認せよ。書いたよ。

usapyoi.hatenablog.com

まとめ

EMアルゴリズムを私なりに説明した。難解なことをやっているようで(実際しっかり理解しようとすると相当厄介なのだが)実際はそこまででもない。KLダイバージェンスの説明や行間の式変形、また理論的に解が得られる保証を省略したので、概要を掴んだ後はより詳しいサイトに行ってみよう。

おまけ

変分ベイズとかもかっこいいよね。響きが。