変分ベイズを今更理解する

趣旨

変分ベイズを理解する。

なぜ変分ベイズを?

この前EMアルゴリズムを解説した。

usapyoi.hatenablog.com

じゃあ次は変分ベイズだね。


変分ベイズはEMより圧倒的に数式を追うのが難しいので頑張りましょう。

表記法

  • \boldsymbol{X}: 観測変数(即ち入力)
  • \boldsymbol{Z}: 潜在変数

モデルパラメタ\boldsymbol{\theta}は?

変分ベイズの枠組みでは、モデルパラメタ\boldsymbol{\theta}は"未知の変数"という意味で潜在変数\boldsymbol{Z}に含まれる。ここがEMアルゴリズムと大きく違う点である。

問題設定

同時分布 p(\boldsymbol{X},\boldsymbol{Z})から、事後分布 p(\boldsymbol{Z} | \boldsymbol{X})や、モデルエビデンス p(\boldsymbol{X})を計算したい。事後分布がわかればそれを最大にする潜在変数 \boldsymbol{Z}も分かる。そして変分ベイズでは \boldsymbol{Z}にモデルパラメタが含まれるので、"最適な"モデルパラメタが \mathbb{E} \lbrack p(\boldsymbol{Z} | \boldsymbol{X}) \rbrackとしてわかるのである。

  • EMアルゴリズムでも事後分布 p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta}^{(i)})をEステップで用いている。ただしEMにおける潜在変数はモデルパラメタを含まないので、 p(\boldsymbol{z} | \boldsymbol{x}, \boldsymbol{\theta}^{(i)})を知っているからといってモデルパラメタが直ちに求まるわけではない。
  • モデルエビデンスの解釈は他サイトに譲る。知らんし

しかし、EMアルゴリズムが想定する状況ほど簡単には得られないものとする。

さて早速解法に \cdotsと行きたいところだが、前提と仮定が入り組みすぎてて面倒なので一旦後回しにする。うるせぇ早く見せろって人は第5章へどうそ。

1.  q(\boldsymbol{Z})による周辺化

 p(\boldsymbol{X},\boldsymbol{Z}) p(\boldsymbol{Z} | \boldsymbol{X}) p(\boldsymbol{X})の関係性は、EMアルゴリズムの時のように任意の分布 q(\boldsymbol{Z})を導入して、


\begin{align}
\log p(\boldsymbol{X})
 &= \log \left\lbrace \int p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}\right\rbrace \\\
 &= \log \left\lbrace \int q(\boldsymbol{Z}) \frac{p(\boldsymbol{X}, \boldsymbol{Z})}{q(\boldsymbol{Z})} d \boldsymbol{Z} \right\rbrace \\\
 &= \int q(\boldsymbol{Z}) \log \left\lbrace \frac{p(\boldsymbol{X}, \boldsymbol{Z})}{q(\boldsymbol{Z})} \right\rbrace d \boldsymbol{Z} - \int q(\boldsymbol{Z}) \log \left\lbrace \frac{p(\boldsymbol{Z} | \boldsymbol{X})}{q(\boldsymbol{Z})} \right\rbrace d \boldsymbol{Z}
\end{align}

と表され、さらに最右辺第一項を汎関数 \mathcal{L}(q)、第二項をKLダイバージェンス \text{KL}(q || p)で表現すると、


\begin{align}
\log p(\boldsymbol{X})
 &= \mathcal{L}(q) + \text{KL}(q || p)
\end{align}

のように非常に簡素に書ける。

  • 汎関数とは、ざっくり言えば「関数を入力とする関数」のこと。今回の例では \mathcal{L}(q)は確率分布 q(\boldsymbol{Z})を入力としている
  • KLダイバージェンスは、確率分布同士の違い(乖離度)を示す指標である。今回だと q(\boldsymbol{Z}) p(\boldsymbol{Z} | \boldsymbol{X})を比較しており、 q(\boldsymbol{Z}) = p(\boldsymbol{Z} | \boldsymbol{X})の時最小値 0を取る。調べましょう。
  • KLダイバージェンス 0以上であることが保証されている。よって、上の式から以下のように表現でき、これは以前Jensenの不等式だのなんだので導出した不等式と等価であることがわかる

\begin{align}
\log p(\boldsymbol{X}) &\geq \mathcal{L}(q)
\end{align}

さて、左辺は定数であるため、 \mathcal{L}(q)が最大の時 \text{KL}(q || p)=0、つまり q(\boldsymbol{Z})=p(\boldsymbol{Z} | \boldsymbol{X})が成立するのだが、EMアルゴリズムの時と違い p(\boldsymbol{Z} | \boldsymbol{X})が計算できない状況である。

2. 計算できないなら近似するしかない-平均場近似

変分ベイズでは、 q(\boldsymbol{Z})を以下のような式で拘束し、拘束条件下で p(\boldsymbol{Z} | \boldsymbol{X})に近づけることを考える。


\begin{align}
q(\boldsymbol{Z}) = \prod_{i=1}^M q_i(\boldsymbol{Z}_i)
\end{align}

つまり、潜在変数 \boldsymbol{Z} M個のグループ \boldsymbol{Z}_1, \boldsymbol{Z}_2, \cdots, \boldsymbol{Z}_Mに分割して、グループ間は統計的に独立であると仮定している。

  • 例えば、GMMでは混合割合 \boldsymbol{\pi}のグループや、平均 \boldsymbol{\mu}のグループ、また分散共分散行列 \boldsymbol{\Sigma}のグループがあった。グループ内はともかく、グループ間でこれらが依存しあっているとは考えにくい。

この仮定は平均場近似と呼ばれる。

それの何が嬉しいの?

平均場近似によって、 q(\boldsymbol{Z})上の \mathcal{L}(q)の最大化から各 iについての q_i(\boldsymbol{Z}_i)上の \mathcal{L}(q)の最大化へと分けることができる。以下で、 \mathcal{L}(q)の中の q_j(\boldsymbol{Z}_j)に依存する項を探してみよう(頑張りましょう)。

3.  \mathcal{L}(q)の中の q_j(\boldsymbol{Z}_j)に依存する項を探す

3-1. 重積分への変形

まず、潜在変数 \boldsymbol{Z}全体の積分から、独立関係にある潜在変数のグループ \boldsymbol{Z}_mそれぞれの積分の繰り返し、即ち重積分で表現する(あとついでに q(\boldsymbol{Z})も変換する)。


\begin{align}
\mathcal{L}(q)
 &= \int q(\boldsymbol{Z}) \log \left\lbrace \frac{p(\boldsymbol{X}, \boldsymbol{Z})}{q(\boldsymbol{Z})} \right\rbrace d \boldsymbol{Z} \\\
 &= \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \frac{p(\boldsymbol{X}, \boldsymbol{Z})}{\prod_{i=1}^M q_i (\boldsymbol{Z}_i)} \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
\end{align}

3-2.  \logに注目して分解する

 \logの性質より2つに分解する。見るのも嫌になるがやっていることはまだ単純。


\begin{align}
\mathcal{L}(q)
 &= \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \frac{p(\boldsymbol{X}, \boldsymbol{Z})}{\prod_{i=1}^M q_i (\boldsymbol{Z}_i)} \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 & \quad - \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i'=1}^M q_{i'} (\boldsymbol{Z}_{i'}) \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
\end{align}

3-3. それぞれ見ていく

ここからは、最右辺の2つの項のそれぞれについて式変形を試みる。

3-3-1. 第一項について

3-3-1-1.  q_j(\boldsymbol{Z}_j)積分を一番外に持ってくる

積分の順序交換(数学的にそんなことして良いのかは触れない)によって、 q_j(\boldsymbol{Z}_j)積分を一番外側に持ってくる。


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq q} q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \right\rbrack d\boldsymbol{Z}_j \\\
\end{align}

3-3-1-2.  \lbrack \cdot \rbrackの中身を簡単にする

上の式をよく見ると、 \lbrack \cdot \rbrackの中身は期待値である(ここで、あくまで \lbrace \boldsymbol{Z}_1,  \cdots , \boldsymbol{Z}_{j-1},  \boldsymbol{Z}_{j+1}, \cdots , \boldsymbol{Z}_M \rbraceにおける期待値であって、 \boldsymbol{Z}_jは含まれていないことに注意)。よって、


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq q} q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \right\rbrack d\boldsymbol{Z}_j \\\
 &= \int q_j (\boldsymbol{Z}_j) \hspace{1mm} \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack d\boldsymbol{Z}_j \\\
\end{align}

のように書ける。

3-3-1-3. 確率分布 \tilde{p} (\boldsymbol{X}, \boldsymbol{Z}_j)を導入

さらに状況をわかりやすくするため、以下のような確率分布 \tilde{p} (\boldsymbol{X}, \boldsymbol{Z}_j)を導入する。


\begin{align}
\tilde{p} (\boldsymbol{X}, \boldsymbol{Z}_j) = C \exp \left( \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack \right)
\end{align}

ここで、 Cは規格化条件


\begin{align}
\int \int \tilde{p} (\boldsymbol{X}, \boldsymbol{Z}_j) d\boldsymbol{X} d\boldsymbol{Z}_j = 1
\end{align}

を守るための定数だと考えれば良い。これの対数は


\begin{align}
\log \tilde{p}(\boldsymbol{X}, \boldsymbol{Z}_j) = \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack + \text{Const.}
\end{align}

となる。これの導入により、最終的に第一項は


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \int q_j (\boldsymbol{Z}_j) \hspace{1mm} \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack d\boldsymbol{Z}_j \\\
 &= \int q_j (\boldsymbol{Z}_j) \log \tilde{p}(\boldsymbol{X}, \boldsymbol{Z}_j) d\boldsymbol{Z}_j + \text{Const.} \\\
\end{align}

と書けるのだった。

3-3-2. 第二項について

やることが今までとそこまで変わらないのでどんどん行きましょう。

3-3-2-1.  \logに注目して分解する


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \sum_{i'=1}^M \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log q_{i'} (\boldsymbol{Z}_{i'})  d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
\end{align}

3-3-2-2.  q_j(\boldsymbol{Z}_j)積分を一番外に持ってくる


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \sum_{i'=1}^M \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log q_{i'} (\boldsymbol{Z}_{i'})  d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \sum_{i'=1}^M \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq j} q_i (\boldsymbol{Z}_i) \right\rbrace \log q_{i'} (\boldsymbol{Z}_{i'})  d \boldsymbol{Z}_1 d \cdots d \boldsymbol{Z}_M \right\rbrack d \boldsymbol{Z}_j \\\
\end{align}

3-3-2-3. 定数項をまとめる

今欲しいのは q_j(\boldsymbol{Z}_j)に依存する項であってそれ以外は要らない。全部 \text{Const.}に押し込んでしまおう。


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \sum_{i'=1}^M \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq j} q_i (\boldsymbol{Z}_i) \right\rbrace \log q_{i'} (\boldsymbol{Z}_{i'})  d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \right\rbrack d \boldsymbol{Z}_j \\\
 &= \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq j} q_i (\boldsymbol{Z}_i) \right\rbrace \log q_j (\boldsymbol{Z}_j)  d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \right\rbrack d \boldsymbol{Z}_j + \text{Const.} \\\
\end{align}

3-3-2-4.  \lbrack \cdot \rbrackの中身を簡単にする

今度の \lbrack \cdot \rbrackの中身は \log q_j (\boldsymbol{Z}_j)である。なぜなら \log q_j (\boldsymbol{Z}_j) \lbrace \boldsymbol{Z}_1,  \cdots , \boldsymbol{Z}_{j-1},  \boldsymbol{Z}_{j+1}, \cdots , \boldsymbol{Z}_M \rbraceに依存しないからである。よって、


\begin{align}
\int \int \cdots \int &\left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \\\
 &= \int q_j (\boldsymbol{Z}_j) \left\lbrack \int \cdots \int \left\lbrace \prod_{i \neq j} q_i (\boldsymbol{Z}_i) \right\rbrace \log q_j (\boldsymbol{Z}_j)  d \boldsymbol{Z}_1 \cdots d \boldsymbol{Z}_M \right\rbrack d \boldsymbol{Z}_j + \text{Const.} \\\
 &= \int q_j (\boldsymbol{Z}_j) \log q_j (\boldsymbol{Z}_j) d \boldsymbol{Z}_j + \text{Const.} \\\
\end{align}

お疲れ様でした。

3-4. 結局

上の式を全部代入すれば、


\begin{align}
\therefore \mathcal{L}(q)
 &= \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log p(\boldsymbol{X}, \boldsymbol{Z}) d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 & \quad - \int \int \cdots \int \left\lbrace \prod_{i=1}^M q_i (\boldsymbol{Z}_i) \right\rbrace \log \left\lbrace \prod_{i'=1}^M q_{i'} (\boldsymbol{Z}_{i'}) \right\rbrace d \boldsymbol{Z}_1 d \boldsymbol{Z}_2 \cdots d \boldsymbol{Z}_M \\\
 &= \int q_j (\boldsymbol{Z}_j) \log \tilde{p}(\boldsymbol{X}, \boldsymbol{Z}_j) d\boldsymbol{Z}_j - \int q_j (\boldsymbol{Z}_j) \log q_j (\boldsymbol{Z}_j) d \boldsymbol{Z}_j + \text{Const.} \\\
 &= -\text{KL}(q || \tilde{p})+ \text{Const.}
\end{align}

と書けるのだった。またKLダイバージェンスである。

4. で、どうすんの

第2章で述べたように、変分ベイズにおける最適化は各 iについての q_i(\boldsymbol{Z}_i)上の \mathcal{L}(q)の最大化と言える。そして、 q_j(\boldsymbol{Z}_j)上の \mathcal{L}(q)の最大化は、上で得た結論


\begin{align}
\therefore \mathcal{L}(q)
 &= -\text{KL}(q || \tilde{p})+ \text{Const.}
\end{align}

の最大化、つまり \text{KL}(q || \tilde{p})の最小化と解釈できる。KLダイバージェンスの性質から、最小化は


\begin{align}
\log q_j^{*} (\boldsymbol{Z}_j)
 &= \log \tilde{p}(\boldsymbol{X}, \boldsymbol{Z}_j) \\\
 &= \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack + \text{Const.}
\end{align}

の時成される。ただし、式を見れば明らかなように、 \log q_j^{*} (\boldsymbol{Z}_j)の計算式に他の潜在変数グループ \boldsymbol{Z}_i ( i \neq j)が入ってしまっており、これは完全な解析解を示せていない。

  •  \log q_j^{*} (\boldsymbol{Z}_j)を計算したい  \rightarrow  \log q_i^{*} (\boldsymbol{Z}_i)が必要だから計算したい  \rightarrow  \log q_j^{*} (\boldsymbol{Z}_j)が必要だから計算したい  \rightarrow \cdots 以下ループ

これに対処するために、まず全ての q_i (\boldsymbol{Z}_i)に初期値 q_i^{(0)} (\boldsymbol{Z}_i)を与え、反復的に更新することで解を得る。これは収束することが保証されているらしい。

5. 解法

  1. 潜在変数 \boldsymbol{Z} M個の(統計的に独立と言えそうな)グループ \boldsymbol{Z}_1, \boldsymbol{Z}_2, \cdots, \boldsymbol{Z}_Mに分割し、それぞれに確率分布 q_i (\boldsymbol{Z}_i)を定義して初期値 q_i^{(0)} (\boldsymbol{Z}_i)を与える。
  2.  j = 1, 2, \cdots, Mについて順番に、以下の式に従って更新する。

\begin{align}
\log q_j^{(t+1)} (\boldsymbol{Z}_j)
 &= \mathbb{E}_{i \neq j: \boldsymbol{Z}_i \sim q_i^{(t)} (\boldsymbol{Z}_i)} \lbrack \log p(\boldsymbol{X}, \boldsymbol{Z}) \rbrack + \text{Const.}
\end{align}

6. まとめ

変分ベイズをまとめた。変分ベイズEMアルゴリズムの違いは以下の通り

  • EMアルゴリズムはモデルパラメタ \boldsymbol{\theta}と潜在変数 \boldsymbol{Z}は別という態度をとり、事前確率 p(\boldsymbol{X} | \boldsymbol{\theta})を最大にするような \boldsymbol{\theta}を求める。
  • 変分ベイズはモデルパラメタ \boldsymbol{\theta}は潜在変数 \boldsymbol{Z}に含まれるという態度をとり、事後確率 p(\boldsymbol{Z} | \boldsymbol{X})を近似して潜在変数 \boldsymbol{Z}を求める。

特に、変分ベイズにおけるパラメタはEMアルゴリズムのように直接求まるわけではなく、あくまで事後分布の期待値、つまり \mathbb{E} \lbrack p(\boldsymbol{Z} | \boldsymbol{X}) \rbrackとして求まることに注意。

変分ベイズはGMM(またもや)などに使われる手法である。応用は頑張ってください。

7. おまけ

import numpy as np
from sklearn.mixture import BayesianGaussianMixture
model = BayesianGaussianMixture(n_components=2)
model.fit(X)

sklearn万歳