【徹底解説】変分ベイズをはじめからていねいに

本記事は機械学習の徹底解説シリーズに含まれます。

初学者の分かりやすさを優先するため,多少正確でない表現が混在することがあります。もし致命的な間違いがあればご指摘いただけると助かります。

はじめに

機械学習を勉強したことのある方であれば,変分ベイズ(VB:variational bayes)の難しさには辟易したことがあるでしょう。私自身,学部生時代に意気揚々と機械学習のバイブルと言われている「パターン認識と機械学習(通称PRML)」を手に取って中身をペラペラめくってみたのですが,あまりの難しさから途方に暮れてしまったことを覚えています。

機械学習の登竜門は,変分ベイズ(変分推論)だと私は考えています。また,VAE(変分オートエンコーダ;variational autoencoder)に代表されるように,変分ベイズは最近の深層学習ブームにおいて理論面の立役者となっている側面もあります。一方で,多くの書籍やWeb上の資料では式変形の行間が詰まっていないことがあるため,初学者は必ず変分ベイズで躓くと言っても過言ではありません。

この問題を解決するため,本稿では変分ベイズをはじめからていねいに説明していきます。具体的には「この解説を読んだだけで変分ベイズの概要と実際の応用方法が理解できる」状態を目指します。多少記事が長くなってしまいますが,ゆっくり自分のペースで読み進めていけば,必ず変分ベイズを理解できるはずです。

一般に,難しい概念を噛み砕いて説明するときには,ボトムアップ的に必要となる知識を武装していく方法と,トップダウン的に求められる知識に寄り道していく方法があります。本稿では,両者を組み合わせて説明していきます。最初に,変分ベイズの目的をお伝えします。その上で,必要となる知識をボトムアップ的に積み上げていくという方針を採用します。

STEP
変分ベイズの目的

変分ベイズがどのような枠組みで何のために用いられるのかを説明します。

STEP
各種推定方法の実現

最尤推定・MAP推定・ベイズ推論を実現するための枠組みについて説明します。

STEP
混合ガウス分布への適用

混合ガウス分布を例にとってEMアルゴリズムと変分ベイズの使い方を確認します。

STEP
実装

pythonを使ってEMアルゴリズムと変分ベイズを実装します。

変分ベイズの目的

本章では,変分ベイズがどのような目的で用いられるのかを説明します。先に結論からお伝えすると,変分ベイズは確率モデルの潜在変数・パラメータに関する事後分布を近似するための手法です。そこで,まず最初に確率モデルと事後分布に関する説明から始めていきます。

確率モデルというのは「現象の裏側に何か適当な分布を仮定する」枠組みのことです。私たちの目的は,ある現象を確率分布を用いて記述することです。そのためには,以下のステップが必要になります。

  • ある現象をよく観察して最もよくフィットする既存の確率分布を選択する
  • 仮定した確率分布の形状を決定するパラメータを推定する

ある現象に対して既存の分布を仮定するという操作は,観測データの発生方法に対して尤度関数を定めることに相当します。すると,私たちの最終的な目的は分布の形状を決定するパラメータを推定することになります。

尤度関数と同時にパラメータの情報を与える事前分布を設定すると,同時分布を定めることに相当します。ここでは,ある現象に対する既存の分布として尤度関数を定めていますが,ベイズ推定を行う場合には既存の分布として同時分布を定めます。

パラメータの推定方法は,パラメータを1つの値に決め切ってしまう点推定と,パラメータ自体にも既存の分布を仮定するベイズ推定に分けられます。さらに,点推定は尤度関数を最大にする最尤推定と事後確率を最大にする最大事後確率(MAP: maximum a posteriori)推定に分けられます。

  • 点推定
    • 最尤推定
    • MAP推定
  • ベイズ推定

以下では,数式を用いて点推定とベイズ推定を説明していきます。

点推定

あるパラメータθが与えられたときに,観測データXのふるまいを定義する関数p(X|θ)が尤度関数です。最尤推定では,尤度関数p(X|θ)を最大にするθの値を求めます。

(1)θ^ML=arg maxθ p(X|θ)

条件付き確率の定義より,p(X|θ)は「θが与えられたときのX」の確率を表します。言い換えれば「θが与えられたときにXはどれだけ尤もらしいか」を表しています。最尤推定では,条件付けられたθを変数として読み替えることで,データXに対して最も尤もらしいパラメータθを推定します。

p(X|θ)を関数と呼ぶことに違和感がある人は鋭い視点をお持ちです。測度論を用いると,確率は可測空間上の確率測度(つまり関数)として定義されます。測度論を用いない場合の確率pの定義は,こちらからご確認ください。

一方で,MAP推定では事後確率p(θ|X)を最大にするθの値を求めます。

(2)θ^MAP=arg maxθ p(θ|X)(3)=arg maxθ p(X|θ)p(θ)p(X)(4)=arg maxθ p(X|θ)p(θ)

ただし,変形にはベイズの定理を利用しました。θに関する最大化問題であるため,最終行ではp(X)を無視しました。

計算を簡単にするため,最尤推定やMAP推定では両辺の対数を取った対数尤度関数や対数事後確率を最大化することが多いです。対数関数は単調増加関数であるため,対数を取る前後の最大最小問題は等価になります。

(5)θ^ML=arg maxθ lnp(X|θ)(6)θ^MAP=arg maxθ {lnp(X|θ)+lnp(θ)}

確率モデルをうまく設定できれば,これらの点推定は解析的に解くことも可能です。例えば,正規分布の形状を決定する母平均の最尤推定量は標本平均に一致し,母分散の最尤推定量は標本分散に一致します。これらはラグランジュの未定乗数法を用いて証明されます。

ベイズ推定

ベイズ推定では,点推定とは異なりパラメータを値ではなく分布として求めることで,データへの過学習を防止したり,表現力を上げたりすることができます。パラメータの分布というのは,ある観測データXが得られたときの事後分布p(θ|X)のことを表しています。

(7)p(θ|X)=p(X|θ)p(θ)p(X)

点推定とは異なり,ベイズ推定ではargmaxが付いていませんね。argmaxargminは目的関数が最大・最小になるパラメータを値として求めることを意味しています。耳が痛いようですが,ベイズ推定では仮定した確率分布pの形状を決めるパラメータθの事後分布p(θ|X)を求めます。

点推定の限界として,対象とする現象が多峰性のふるまいを示しているケースが挙げられます。多峰性とは,複数のピークをもつ性質のことを指します。多峰性の現象に対して点推定を行うと,1つのピーク以外の情報を完全に捨て去ってしまいます。

各種推定方法の実現

本章では,各種パラメータの推定方法がどのように実現されるのかについて説明します。先に結論からお伝えすると,点推定を実現する方法はEMアルゴリズム,ベイズ推定を実現する方法は変分ベイズがよく利用されます。

  • 点推定
    • EMアルゴリズム
  • ベイズ推定
    • 変分ベイズ(変分推論)

ちなみに,変分ベイズは文脈によっては「変分推論」とも呼ばれます。ベイズモデリングであることを強調したい場合は「変分ベイズ」と呼ばれることが多いようです。

確率モデルの問題を考えるときには,まず最初に確率モデルに潜在変数Z(尤度関数に姿を現さない変数)を導入します。そうすることで,複雑な分布をより単純な分布を使って表すことが可能になります。潜在変数を導入することは,同時分布p(X,Z)を設計することに相当します。ここが全ての始まりです。観測データXと潜在変数Zにどのような依存関係があるのかを同時分布として設定してしまうわけです。

ベイズ推定では,副次的に尤度関数p(X|Z)と事前分布p(Z)を仮定することになります。

同時分布の設定

尤度関数と事前分布を定めることは,同時分布を定めることだけでなく,事後分布の形を定めることにも相当します。なぜなら,Xが観測データであることからp(X)は定数だからです。

(8)p(Z|X)=p(X,Z)p(X)(9)p(X,Z)(10)=p(X|Z)p(Z)

点推定では,事前分布が無情報(定数)だと仮定します。結局,上でお伝えしたように,点推定では同時分布を設計することは尤度関数を設定することに相当します。

(11)p(X,Z)=p(X|Z)p(Z)(12)=p(X|Z)const(13)p(X|Z)

ベイズ推定では,事前分布には共役事前分布を仮定するケースが多いです。共役事前分布を設定すると,事後分布は事前分布と同じ形になりますので,圧倒的に計算が簡単になります。一方で,全ての尤度関数に対して共役事前分布が存在する訳ではなく,以下の分布対が利用されることが多いです。

代表的な共役事前分布

変分ベイズの本領が発揮されるのは,共役事前分布を定められないとき,もしくは事前分布が部分的な共役性しか持たないときです。今回扱う混合ガウス分布も,部分的な共役性をもつ確率モデルの一例ですので,変分ベイズの本領が発揮されることになります。

EMアルゴリズム

結論からお伝えすると,EMアルゴリズムは以下の流れで計算されます。

EMアルゴリズムのアニメーション

詳しい導出方法は「EMアルゴリズムをはじめからていねいに」をご参照下さい。

変分ベイズ

変分ベイズの目的は,確率モデルの潜在変数・パラメータに関する事後分布を求めることでした。EMアルゴリズムでは潜在変数をZ,パラメータをθと区別しましたが,変分ベイズでは両者を一括りにしてZと表しますので,事後分布はp(Z|X)と表されます。EMアルゴリズムではp(Z|X)を計算できるという立場を取りましたが,変分ベイズではp(Z|X)を計算できないという立場を取ります。

変分ベイズは,事後分布p(Z|X)を別の新しい分布q(Z)で近似してしまおうという大胆かつ汎用性の高い手法です。

(14)q(Z)p(Z|X)

この出発点こそ,変分ベイズが変分法の一種であることを裏付けています。変分法とは,ある関数の最適化問題を解く手法の一つです。今回は「真の事後分布を別の新しい分布q(Z)で近似する」という文脈で変分法を用いています。語弊を恐れずに言うならば,変分法は微分法の拡張です。微分法ではN次元ユークリッド空間における臨界点(微分が0になって関数が平らになる部分)を求めるのに対し,変分法では無限次元関数空間における臨界点を求めます。

結論からお伝えすると,変分ベイズでは以下の更新式を利用して潜在変数・パラメータに関する近似事後分布q(Z)を求めます。

(15)lnqi(Zi)=Eji[lnp(X,Z)]+const

qZの添え字であるijの意味については後ほどお伝えします。

更新式の導出には2通りの方法があります。1つ目は,事後分布とのKLダイバージェンスを最小化するような分布を求める方法です。2つ目は,EMアルゴリズムと同じ枠組みで周辺対数尤度関数を下限とKLダイバージェンスの2つに分解して考える方法です。前者は直感的に理解しやすい方法であり,後者はEMアルゴリズムと数学的な背景を一貫させて理解することができる方法です。それぞれ詳しく見ていきましょう。

KL最小化による近似

私たちの目的は,p(Z|X)をよく近似するq(Z)を求めることでした。2つの分布間の距離を測るオーソドックスな指標としては,KLダイバージェンスが利用できます。指標としてKLダイバージェンスを用いる理由は,必然性を理論的に説明できるからです。この必然性については次の「EMアルゴリズムの類推」で説明します。

2つの分布間の距離を測る指標としてKLダイバージェンスを採用することを受け入れると,p(Z|X)とのKLダイバージェンスが小さくなるようなq(Z)を求めれば良いことになります。EMアルゴリズムでは,真の事後分布を計算できるという立場を取るため,KLダイバージェンスを0にする近似事後分布を求めることができました。一方で,変分ベイズでは真の事後分布を計算することができないという立場を取るため,近似分布が真の事後分布と厳密に等しくなることはないと仮定します。

しかし,何の制限もない中でp(Z|X)とのKLダイバージェンスが小さくなるようなq(Z)を求めるのは自由度が高すぎて困難です。そこで,変分ベイズでは「平均場近似」と呼ばれる仮定を採用します。

(16)q(Z)=i=1Mqi(Zi)

平均場近似は,事後分布p(Z|X)をよく近似するq(Z)は独立な分布qi(Zi)の積で表されるという強い仮定を表しています。ただし,必ずしも全ての要素が独立だという仮定ではないことに注意してください。あくまでも,どのような要素に分解するかの「グループ分け」が平均場近似の仮定だと認識しておきましょう。

すると,グループ分けした潜在変数・パラメータのインデックスをiとおいたときに,求めるqi(Zi)というのは以下の解になります。ただし,qi(Zi)のことをqiと省略して書きます。

(17)arg minqi KL[i=1Mqi(Zi)p(Z|X)]

KLダイバージェンスは非対称ですので,KL(pq)KL(qp)は異なります。前者をForward KL,後者をReverse KLと呼ぶことがあります。KLダイバージェンスの定義より,Forward KLはlog関数をpで重みづけしていますので,p0の部分をqで網羅しようとします。その結果,qの分散は大きくなりやすいです。一方で,Reverse KLはlog関数をqで重みづけしていますので,q0ではない部分でpを網羅しようとします。その結果,qの分散は小さくなりやすいです。例えば,真の事後分布が多峰性の場合,Forward KLでは複数のピークをならしたような近似事後分布が得られるのに対し,Reverse KLではある1つのピークに着目した近似事後分布が得られやすいです。多峰性の場合は近似事後分布に混合ガウス分布を持ち出せばうまくフィッティングできますが,KLダイバージェンスの非対称性を把握しておくことは非常に大切です。

実際に計算していきます。KLダイバージェンスの定義より,以下のように計算することができます。

(18)qi(Zi)=arg minqi KL[i=1Mqi(Zi)p(Z|X)](19)=arg minqi Eq[lni=1Mqi(Zi)p(Z|X)]

今求めたいのはZiに対する近似事後分布ですので,期待値を取る際にはqi(Zi)に対する期待値とqji(Zji)に対する期待値を分けて考えてあげます。qjiは「iとは異なるj」と読むと理解しやすいです。

(20)qi(Zi)=arg minqi Eq[lnqi(Zi)qji(Zji)p(Z|X)](21)=arg minqi Eqi[Eqji[lnqi(Zi)qji(Zji)p(Z|X)]](22)=arg minqi {Eqi[Eqji[lnqi(Zi)]+Eqji[lnqji(Zji)]Eqji[lnp(Z|X)]]}

q(Z)ijiに分けるというアイディアが少し奇抜に感じられるかもしれません。これは後に,KLダイバージェンスを0にするq(Z)を見つけるための苦肉の策です。平均場近似は強い仮定であることから「各因子ごとであればKLダイバージェンスを0にする事後分布を計算できるのではないか」というアイディアに着想した結果です。実際に計算を進めていくと,式(32)でKLダイバージェンスを0にするqi(Zi)を導出できることが分かります。

さて,ここで期待値に関する以下の性質を利用します。

  • 1の期待値は1

(23)EX[1]=1

  • 期待値の対象と中身が独立な(定数とみなせる)場合に期待値の外に出せる [参考]

(24)EX[aX+b]=aEX[X]+b

  • 確率変数の過不足なく期待値を取った結果はスカラー

(25)EY[EX[XY]]=EY[μxY]=μxEY[Y]=μxμy

最後の項目について少し補足しておきます。期待値の定義注目している確率変数に対する周辺化操作に相当しますので,確率変数に過不足なく期待値を取った結果はスカラーになります。

以上を踏まえると,以下が成り立ちます。

(26)Eqji[lnqi(Zi)]=lnqi(Zi)(27)Eqji[lnqji(Zji)]=const

したがって,先ほどの計算を進めることができます。

(28)qi(Zi)=arg minqi {Eqi[lnqi(Zi)+constEqji[lnp(Z|X)]]}(29)=arg minqi {Eqi[lnqi(Zi)Eqji[lnp(Z|X)]]}(30)=arg minqi {Eqi[lnqi(Zi)exp(Eqji[lnp(Z|X)])]}(31)=arg minqi {Eqi[lnqi(Zi)exp(Eqji[lnp(Z|X)])/ClnC]}(32)=arg minqi KL[qi(Zi)exp(Eqji[lnp(Z|X)])C]

KLダイバージェンスを最小にするのは2つの分布が等しいときですので,結局以下が得られます。

(33)qi(Zi)=exp(Eqji[lnp(Z|X)])C

両辺の対数を取ると,冒頭で説明した公式(15)が得られます。

(34)lnqi(Zi)=Eqji[lnp(Z|X)]+const(35)=Eqji[lnp(X,Z)p(X)]+const(36)=Eqji[lnp(X,Z)]+const

「自分以外全ての潜在変数・パラメータで仮定した確率モデルの期待値を取ると近似事後分布の形が得られる」と理解しましょう。冒頭にもお伝えした通り,我々の目標は得られたデータXの背後に潜む事後分布p(Z|X)の形を求めることでした。確率モデルの問題では,確率変数同士の依存関係を同時分布p(X,Z)としてこちら側で定めてしまうのでした。ゆえに,式(15)の右辺を計算することができます。

EMアルゴリズムの類推

ここまでの方針は,p(Z|X)を良く表す近似事後分布q(Z)を,KLダイバージェンスという恣意的な指標を用いて測りました。そこで,ここからはKLダイバージェンスが出てくる必然性を説明していきます。

思い出して欲しいのは,EMアルゴリズムにおける下限の導出過程です。下限と対数尤度関数の差は,KLダイバージェンスの形になりましたね。この経験から類推するに,変分ベイズでも目的関数に対してイェンセンの不等式を適用すれば,KLダイバージェンスが出現するはずです。

さて,変分ベイズの目的関数というのは一体何なのでしょうか。EMアルゴリズムの目的は対数尤度関数の最大化でしたので,目的関数が自明でした。しかし,変分ベイズの目的は事後分布pをよく表す近似事後分布qを見つけることでしたので,目的関数は自明ではありません。目的ドリブンで考えると,変分ベイズの目的関数がqpのKLダイバージェンスになるというのは「距離指標としてKLダイバージェンスを採用する必然性が(今のところ)ないこと」を除いて理にかなっています。

こういうときは,出発点に立ち返ることが大切です。私たちの目的は,ある現象を確率分布を用いて記述することです。そのためには,以下のステップが必要になるのでした。

  • ある現象をよく観察して最もよくフィットする既存の確率分布を選択する
  • 仮定した確率分布の形状を決定するパラメータを推定する

点推定では「パラメータの値」自体に興味があるため,対数尤度関数をパラメータの関数と読み替えて最大化問題を解きました。その際,潜在変数の出現による和の対数部分が計算困難であるため,イェンセンの不等式を利用して対数尤度関数を下から評価したのでした。

(37)lnp(X|θ)=LML[q(Z),θ]+KL[q(Z)p(Z|X,θ)]

変分ベイズでは,パラメータと潜在変数を同一視しますので,θq(Z)に吸収されます。つまり,式(37)の両辺におけるθを単に消去して,q(Z)の中にθを含めてあげればよいのです。

(38)lnp(X)=LVB[q(Z)]+KL[q(Z)p(Z|X)]

なぜパラメータを潜在変数に含めるのかについては,変分ベイズではパラメータの値自体には興味がなく,パラメータの分布に興味があるからです。興味の対象はθではなくqになりますので,θZに含めてしまってqだけに注目するように仕向けたのです。

すると,変分ベイズの目的関数は式(38)の左辺,すなわち周辺対数尤度関数であることが分かりました。ここで注意するべきなのは,Xは観測データですので周辺対数尤度関数は定数であるということです。この事実は変分ベイズ法の解説でよく誤解されている点ですので,おさえておくと良いと思います。

このことから,周辺対数尤度関数を目的関数に掲げるのは不適切といえます。そこで,式(38)の右辺に注目しましょう。対数周辺尤度関数の下限とKLダイバージェンスから構成されていますね。何を隠そう,第二項目のKLダイバージェンスは前の方針で考えたqpの距離を測る指標そのものです。したがって,変分ベイズの目的関数はqpのKLダイバージェンスに帰着します。

先ほど「距離指標としてKLダイバージェンスを採用する必然性が(今のところ)ない」とお伝えしましたが,KLダイバージェンスを用いる必然性は,目的関数の下限をイェンセンの不等式で評価していたことに裏付けられているのです。イェンセンの不等式を用いると言う前提に立つ場合には,KLダイバージェンスを用いる必然性は担保されるということです。

ここまでをまとめます。変分ベイズの目的関数は対数周辺尤度関数ですが,一定値であるため目的関数としては不適切です。そこで,対数周辺尤度関数を構成する2つの項に注目すると,pをよく近似するqを見つけるためにKLダイバージェンスを目的関数に設定すれば良いことが分かりました。

一方で,式(38)の右辺の和が一定値であることに注意すると,KLダイバージェンスを最小化することは下限を最大化することと等価です。したがって,変分ベイズでは2つの等価な目的関数が存在することになります。ただし,KLダイバージェンスでは最小化問題,下限では最大化問題を考えるという点に注意してください。

最後に,EMアルゴリズムと変分ベイズの比較表を載せておきます。

EMアルゴリズムと変分ベイズの比較

LVB[q(Z)]のことをELBO(evidence lower bound)や変分下限(variational lower bound: VLB)と呼びます。先ほどお伝えしたように,lower boundは下界ですから,VLBは正確には変分下界と訳されるべきですが,変分ベイズでは専らイェンセンの不等式により下界を導出することから,下界を下限と呼ぶようになったと推察されます。

最尤推定とMAP推定とのつながり

本章では,変分ベイズの最尤推定とMAP推定とのつながりを説明していきましょう。結論からお伝えすると,最尤推定とMAP推定は,変分ベイズの特殊な場合に相当します。このような文脈で,変分ベイズは点推定の上位互換の概念であるといえるでしょう。KLダイバージェンスを目的関数と捉えた場合でも,下限を目的関数と捉えた場合でも,最尤推定とMAP推定とのつながりを美しく説明することができます。

変分ベイズがどのような場合でも必ず点推定よりも優れているという訳ではありません。上位互換というのはあくまでも概念としての包含関係の話であって,パラメータ推定の性能に関する話ではありません。実際には,与えられたデータ量や利用できるリソースに応じて,点推定とベイズ推定のメリットとデメリットを比較しながら臨機応変に使い分ける必要があります。

KL最小化

ベイズ推論ではパラメータの近似事後分布を求めるのでした。点推定ではパラメータの値を求めるのでした。そこで,ベイズ推論における近似事後分布を「一点にしか値を持たない」関数とすれば,ベイズ推論は点推定と等価になります。数学の世界では「一点にしか値を持たない」関数としてディラックのデルタ関数が有名です。

ディラックのデルタ関数

任意の実連続関数f:RRに対し,

(39)f(x)δ(x)dx=f(0)

を満たす実数値シュワルツ超関数δをディラックのデルタ関数と呼ぶ。

例えば,ディラックのデルタ関数をaだけ平行移動させると,以下のようにf(a)が抽出されます。

(40)f(x)δ(xa)dx=f(a)

したがって,パラメータを含む潜在変数の最適解をZ^とおくと,近似事後分布にディラックのデルタ関数δ(ZZ^)を仮定すればよいことが分かります。

KL[δ(ZZ^)p(Z|X)](41)=δ(ZZ^)lnδ(ZZ^)p(Z|X)dZ(42)=δ(ZZ^)lnδ(ZZ^)dZδ(ZZ^)lnp(Z|X)dZ(43)=lnδ(Z^Z^)lnp(Z^|X)(44)=lnδ(0)lnp(Z^|X)

結局,最適解Z^は以下のように表されます。lnδ(0)は定義されませんが,Z^とは関係ないため無視できる点がポイントです。

(45)arg minZ^ KL[δ(ZZ^)p(Z|X)]=arg minZ^ {lnδ(0)lnp(Z^|X)}(46)=arg maxZ^ lnp(Z^|X)(47)=arg maxZ^ p(Z^|X)

これはMAP推定そのものを表していますよね。美しいです。さらに,MAP推定において事前分布が一様(つまりパラメータに関する情報が得られない)と仮定したケースは,最尤推定に相当するはずです。そこで,式(47)において事前分布として定数p(Z^)=CMLを仮定してみましょう。

(48)arg minZ^ KL[δ(ZZ^)p(Z|X)]=arg maxZ^ lnp(Z^|X)(49)=arg maxZ^ lnp(X|Z^)p(Z^)p(X)(50)=arg maxZ^ {lnp(X|Z^)+lnp(Z^)lnp(X)}(51)=arg maxZ^ {lnp(X|Z^)+CMLlnp(X)}(52)=arg maxZ^ lnp(X|Z^)(53)=arg maxZ^ p(X|Z^)

これは最尤推定そのものを表していますよね。美しいです。ここまでの流れをまとめておきましょう。

変分ベイズのKL最小化の文脈において,以下の仮定と推定方法が対応する。

  • 近似事後分布にディラックのデルタ関数を仮定
    • MAP推定に相当
  • 近似事後分布にディラックのデルタ関数を仮定かつ事前分布に定数を仮定
    • 最尤推定に相当

シュワルツ超関数は単に超関数とも呼ばれていて,関数を拡張した概念です。ディラックのデルタ関数は元々,空間の一点にだけ存在する粒子を表現するために考案された関数です。上の定義の通り,積分して初めて意味をもつことから,一般的な関数とは異なる超関数として定義されています。

下限最大化

下限最大化の文脈においては,KL最小化の文脈よりもシンプルにMAP推定と最尤推定との繋がりを示すことができます。EMアルゴリズムにおける下限LMLの定義より,変分ベイズにおける下限LVBは以下のように計算されます。ただし,変分ベイズでは潜在変数を連続値として扱うため,シグマを積分で置き換えています。

(54)LVB[q(Z)]=Zq(Z)lnp(X,Z)q(Z)dZ(55)=Zq(Z)lnp(X|Z)p(Z)q(Z)dZ(56)=Zq(Z)lnp(X|Z)dZZq(Z)lnq(Z)p(Z)dZ(57)=Zq(Z)lnp(X|Z)dZKL[q(Z)p(Z)]

変分ベイズでは,この下限を最大化していくのでした。したがって,変分ベイズでは式(57)の第一項目の下限は最大化され,第二項目のKLダイバージェンスは最小化されます。同時に二つの項を考えると大変ですので,それぞれを最大化する場合について考えてみましょう。

第一項目の最大化は対数尤度に対応しますので,第一項目だけを考えると最尤推定に相当することが分かります。実際,第一項目を最大にするqは,対数尤度を最大にするZ^で値をもつディラックのデルタ関数δ(ZZ^)になります。なぜなら,期待値は分布の平均に相当する概念を表しますが,期待値を最大にするためには分布の最大値を抽出すれば良いからです。この結果は,KL最小化文脈の考察と一致します。

第二項目は負のKLダイバージェンスに対応しますので,第二項目だけを考えると「事後分布q(Z)を事前分布p(Z)に近づける」ことに相当することが分かります。これは,MAP推定における事前分布の正則化項に対応しており,KL最小化文脈の考察と一致します。

結局,下限を二つの項に分解すると以下のような結果になることが分かりました。用語に数学的な厳密性は担保していません。

(58)下限=qによる対数尤度の期待値+近似分布と事前分布の負の距離(59)=最尤推定項+正則化項

このように,下限を最大化することは,第1項目を大きくする(最尤推定の効果を大きくする)か,第2項目を大きくする(正則化の効果を大きくする)かのバランスを取っていることが分かりました。以上をまとめます。

変分ベイズの下限最大化の文脈において,下限は以下のように分解される。

(60)下限=最尤推定項+正則化項

混合ガウス分布への適用

本章では,混合ガウス分布(GMM:Gaussian Mixture Model)を題材に,EMアルゴリズムと変分ベイズを応用する方法をお伝えしていきます。

EMアルゴリズム

結論からお伝えすると,EMアルゴリズムのGMMへの適用は以下のようにまとめられます。

EMアルゴリズムのGMMへの適用(GMM-EM)は,以下の4ステップから構成される。

  1. 初期値設定
  2. 負担率rnkQ関数の計算(Eステップ)
  3. 各パラメータの更新(Mステップ)
  4. 収束判定

まずは各パラメータに初期値μ0Σ0π0を与える。次に,以下のEステップとMステップを収束するまで繰り返す。

Eステップ

以下の負担率rnkQ関数を計算する。

(61)rnk=πkN(xn|μk,Σk)j=1KπjN(xn|μj,Σj)(62)Q(θ)=n=1Nk=1Krnk[12{ln|Σk|(xnμk)TΣk1(xnμk)}]

Mステップ

Q関数を最大にするように以下のパラメータを計算する。

(63)Nk=n=1Nrnk(64)μk=n=1NrnkxnNk(65)Σk=n=1Nrnk(xnμk)(xnμk)TNk(66)πk=NkN

詳しい導出方法は「EMアルゴリズムをはじめからていねいに」をご参照下さい。

変分ベイズ

本節では,以下の流れで変分ベイズをGMMのパラメータ推論に適用していきます(GMM-VB)。

  1. 同時分布の依存関係確認
  2. 平均場近似のグループ分けの仮定
  3. 更新式適用による変分下限の最大化

1. 同時分布の依存関係確認

本節では,潜在変数とパラメータの依存関係をグラフィカルモデルを通して確認します。GMMのパラメータは,潜在変数も含めると,Z,π,μ,Σです。今回仮定する依存関係は,一般的によく利用されるものです。以下にグラフィカルモデルを示します。

GMM-VBのグラフィカルモデル

これを数式を用いて表すと,以下のようになります。

(67)p(X,Z,π,μ,Σ)=p(X|Z,μ,Σ)p(Z|π)p(π)p(μ|Σ)p(Σ)

今回は,各分布を以下のように定めます。

(68)p(X|Z,μ,Σ)=n=1Nk=1KN(xn|μk,Σk1)znk(69)p(Z|π)=n=1Nk=1Kπkznk(70)p(π)=Dir(π|α0)(71)p(μ|Σ)=k=1KN{μk|m0,(β0Σk)1}(72)p(Σ)=k=1KW(Σk|W0,ν0)

EMアルゴリズムとの対応を図るために,Zの周辺尤度関数も確認しておきます。

(73)p(X|μ,Σ,π)=Zp(X,Z|μ,Σ,π)(74)=Zp(X|Z,μ,Σ)p(Z|π)(75)=Z[n=1Nk=1K{πkN(xn|μk,Σk1)}znk](76)=n=1Nk=1KπkN(xn|μk,Σk1)

ただし,最終行は「one-hot-vectorZの全ての候補について足し合わせる」操作を「全てのクラスkについて和をとる」操作に読み替えて変形しました。この結果は,GMM-EMにおける以下の周辺尤度関数と矛盾しません。

(77)p(xn|π,μ,Σ)=k=1KπkN(xn|μk,Σk)

繰り返しにはなりますが,変分ベイズはEMアルゴリズムとは異なりベイズ推定を行いますので,尤度関数に加えてπμΣの事前分布を設定します。

分散共分散行列に関するガウス分布の共役事前分布は逆ウィシャート分布になります。一方で,分散共分散行列の逆行列に関するガウス分布の共役事前分布はウィシャート分布になります。今回は,素朴なウィシャート分布を持ち出すために,ガウス分布の分散共分散行列を逆行列で定めます。なお,分散共分散行列の逆行列のことを精度行列(precision matrix)と呼びます。

以下では,簡単に尤度関数と事前分布の根拠を説明しておきます。

EMアルゴリズムと同様に,Zの条件付き分布はznπの性質から自然に導かれる関係です。znはone-hot-vectorですので,znの生成確率をπkを利用して表現すれば,上のように累乗と総乗で表すことができるのです。データの条件付き確率も同様に,znがone-hot-vectorであることを利用して表現できます。

πの事前分布がディリクレ分布であるのは,p(Z|π)が多項分布の形(n=1のカテゴリカル分布)をしており,その共役事前分布として定めているからです。同様に,多変量正規分布p(X|Z,μ,Σ)μに関する共役事前分布としてガウス分布,精度行列(分散共分散行列の逆数)に関する共役事前分布としてウィシャート分布を設定しています。

多くの書籍やWeb上の資料では「p(μ,Σ)の共役事前分布としてガウスウィシャート分布を仮定する」と説明されています。しかし,本稿では上で示したグラフィカルモデルに基づいた説明を心掛けています。すなわち,まずΣが生成されて,そのΣに依存してμが生成されるというフローを忠実に再現するために,μに関する共役事前分布とΣに関する共役事前分布を別々に説明しました。

以下では,これらの分布を使って計算を行うために定義を確認しておきましょう。具体的には,ディリクレ分布多変量正規分布,ウィシャート分布の確率密度関数と期待値を知っておく必要があります。ディリクレ分布とウィシャート分布に関しては,後でE[lnx]を利用することになるため,ここで一緒に確認しておきましょう。

ディリクレ分布の確率密度関数・E[X]E[lnx]は以下のように表されます。

ディリクレ分布

(78)fX(x)=1B(α)i=1nxiαi1(79)E[Xi]=αii=1nαi(80)E[lnXi]=ψ(αi)ψ(i=1nαi)

ただし,B()はベータ関数,ψ()はディガンマ関数を表す。

(81)B(a,b)=01xa1(1x)b1dx(82)ψ(a)=ddalnΓ(a)

多変量正規分布の確率密度関数・E[X]は以下のように表されます。

多変量正規分布

(83)fX(x)=1(2π)d/2|Σ|1/2exp{12(xμ)TΣ1(xμ)}(84)E[X]=μ

ウィシャート分布の確率密度関数・E[X]E[lnx]は以下のように表されます。

ウィシャート分布

(85)W(Σ|W,ν)=C(W,ν)|Σ|(νD1)/2exp(12Tr[W1Σ])(86)E[Σ]=νW(87)E[ln|Σ|]=i=1Dψ(ν+1i2)+Dln2+ln|W|

ただし,C(,)は以下で表される定数,ψ()はディガンマ関数である。

(88)C(W,ν)=|W|ν/2{2νD/2πD(D1)/4i=1DΓ(ν+1i2)}1(89)ψ(a)=ddalnΓ(a)

特に,Wが対称行列であることに注意されたい。

分布を利用する準備が整いました。これにて「1. 同時分布の依存関係の確認」が完了しました。

2. 平均場近似のグループ分けの仮定

本節では,平均場近似のグループ分けを考えていきます。今回は,潜在変数と他のパラメータが独立であることを仮定しましょう。

(90)q(Z,π,μ,Σ)=q(Z)q(π,μ,Σ)

この仮定はZ{π,μ,Σ}が独立であることを仮定しているだけであり,Zπが独立であることは仮定していません。Zπが独立であるための必要十分条件は,Z,πの同時分布がそれぞれの確率分布の積で分離されることであり,Z,π以外も含まれる同時分布からZのみが分離されたからといってZπが独立になるとは限りません。

「1. 同時分布の依存関係確認」で仮定したように,π(μ,Σ)は独立ですから,q(π,μ,Σ)q(π)q(μ,Σ)に分解してもよいです。しかし,次節で計算して分かる通り,分解せずともπμ,Σが独立である結果が導かれますから,ここではより弱い仮定を採用することにします。これにて「2. 平均場近似のグループ分けの仮定」が完了しました。

3. 更新式適用による変分下限の最大化

本節では,変分ベイズの更新式を用いて変分下限を最大化します。「2. 平均場近似のグループ分けの仮定」では,潜在変数Zの属する分布とパラメータπ,μ,Σの属する分布が別々であると仮定しました。そこで,以下では潜在変数に関する最適化を考えた後に,パラメータに関する最適化を考えます。EMアルゴリズムの類推により,前者の最適化をEステップ,後者の最適化をMステップとみなせます。

管理人の考えでは,変分ベイズではEステップとMステップを分ける意味はあまりないと考えています。なぜなら,変分ベイズではEMアルゴリズムとは異なり,パラメータと潜在変数を同一視するからです。一方で,本稿の趣旨はEMアルゴリズムとの比較を通じて変分ベイズをはじめからていねいに理解することです。そのため,今回は変分ベイズにおけるEステップを変分Eステップ,Mステップを変分Mステップと表記することにします。詳しくは末尾の付録で説明します。

最初に結論からお伝えします。

変分ベイズのGMMへの適用(GMM-VB)は,以下の4ステップから構成される。

  1. 初期値設定
  2. 負担率rnkの計算(変分Eステップ)
  3. 各パラメータの更新(変分Mステップ)
  4. 収束判定

まずは各パラメータに初期値α0β0ν0m0W0を与える。次に,以下の変分Eステップと変分Mステップを収束するまで繰り返す。

変分Eステップ

負担率rnkを計算する。

(91)rnk=ρnkj=1Kρnj

ただし,ρnkは以下で計算できる。

(92)lnπ~k=ψ(αk)ψ(k=1Kαk)(93)lnΣ~k=i=1Dψ(νk+1i2)+Dln(2)+ln|Wk|(94)ρnk=π~kΣ~k12exp{D2βkνk2(xnmk)TWk(xnmk)}

変分Mステップ

パラメータを更新する。

(95)αk=α0+Nk(96)βk=β0+Nk(97)νk=ν0+Nk(98)mk=1βk(β0m0+Nkxk)(99)Wk1=W01+NkSk+β0Nkβ0+Nk(xkm0)(xkm0)T

ただし,NkxkSkは変分Eステップで計算した負担率rnkを用いて計算できる。

(100)Nk=n=1Nrnk(101)xk=1Nkn=1Nrnkxn(102)Sk=1Nkn=1Nrnk(xnxk)(xnxk)T

GMM-VBのアニメーション
変分Eステップ

まずは変分Eステップの導出から始めましょう。最適なqであるq(Z)を求めます。変分ベイズの更新式を適用すると,以下のようにq(Z)を計算することができます。

(103)lnq(Z)=Eπ,μ,Σ[lnp(X,Z,π,μ,Σ)]+const(104)=Eπ[lnp(Z|π)]+Eμ,Σ[lnp(X|Z,μ,Σ)]+const=n=1Nk=1KznkE[lnπk]+12E[ln|Σk|]D2ln(2π)(105)12Eμk,Σk[(xnμk)TΣk(xnμk)]+const(106)=n=1Nk=1Kznklnρnk+const

ただし,lnρnkは以下のように置きました。

lnρnk=E[lnπk]+12E[ln|Σk|]D2ln(2π)(107)12Eμk,Σk[(xnμk)TΣk(xnμk)]

一旦,定数項constを無視すると,q(Z)は以下のように求められます。

(108)q(Z)n=1Nk=1Kρnkznk

以下では,厳密にq(Z)を定めていきます。すなわち,先ほど無視した定数項constによる正規化定数を求めます。求める正規化定数をAとおくと,

(109)1A=Zn=1Nk=1Kρnkznk(110)=(z1k=1Kρ1kz1k)(z2k=1Kρ2kz2k)(111)=(k=1Kρ1k)(k=1Kρ2k)(112)=n=1Nk=1Kρnk

と計算することができます。ただし,途中でzがone-hot-vectorであることを利用しました。したがって,q(Z)は正確に以下のように表されます。

(113)q(Z)=1An=1Nk=1Kρnkznk(114)=n=1Nk=1Kρnkznkn=1Nk=1Kρnk(115)=n=1Nk=1Kρnkznkk=1Kρnk(116)=k=1Nk=1Kρnkznkk=1K{j=1Kρnj}znk(117)=n=1Nk=1K{ρnkj=1Kρnj}znk(118)=n=1Nk=1Krnkznk

ただし,rnkは以下のように置きました。

(119)rnk=ρnkj=1Kρnj

このrnkがEMアルゴリズムにおける負担率となります。その証明としては,式(118)の期待値がrnkとなることから確認できます。znk1のときだけrnkが寄与するからです。

(120)E[znk]=rnk

この式は,EMアルゴリズムで定義した負担率と全く同じです。

最後に,式(107)と事前分布の仮定を利用して,rnkを計算するためにρnkを求めましょう。その前に,D次元ベクトルxμN(m,(βΣ)1),およびΣW(Σ|W,ν)に対し,

(121)Eμ,Σ[(xμ)TΣ(xμ)]=ν(xm)TW(xm)+β1D

が成り立つことを証明します。N(m,Σ)の確率密度関数をfN(m,Σ)とおいてEμ,Σの定義から計算すると,

(122)Eμ,Σ[(xμ)TΣ(xμ)]=(xμ)TΣ(xμ)fN(m,Σ)dμdΣ(123)={(xμ)TΣ(xμ)fN(m|Σ)dμ}fN(Σ)dΣ(124)=Eμ[(xμ)TΣ(xμ)]fN(Σ)dΣ(125)=Eμ[Tr{Σ(xμ)(xμ)T}]fN(Σ)dΣ(126)=Eμ[Tr(Σ(xxT2μxT+μμT))]fN(Σ)dΣ(127)=Tr[Eμ[(Σ(xxT2μxT+μμT))]]fN(Σ)dΣ(128)=Tr[Σ(xxT2Eμ[μ]xT+Eμ[μμT])]fN(Σ)dΣ(129)=Tr[Σ(xxT2mxT+mmT+(βΣ)1)]fN(Σ)dΣ(130)=Tr[Σ(xm)(xm)T+Σ(βΣ)1]fN(Σ)dΣ(131)=Tr[Σ(xm)(xm)T+β1ID]fN(Σ)dΣ(132)={Tr[Σ(xm)(xm)T]+Tr[β1ID]}fN(Σ)dΣ(133)={(xm)TΣ(xm)+β1D}fN(Σ)dΣ(134)=EΣ[Tr[Σ(xm)(xm)T]+β1D](135)=EΣ[Tr[Σ(xm)(xm)T]]+β1D(136)=Tr[EΣ[Σ(xm)(xm)T]]+β1D(137)=Tr[EΣ[Σ](xm)(xm)T]+β1D(138)=Tr[νW(xm)(xm)T]+β1D(139)=(xm)TνW(xm)+β1D(140)=ν(xm)TW(xm)+β1D

となり,式(121)が示されました。ただし,XTAX=Tr[AXXT]であること,分散共分散行列の定義より

(141)(βΣ)1=E[(μm)(μm)T](142)=E[μμT]2mE[μT]+mmT(143)=E[μμT]2mmT+mmT(144)=E[μμT]mmT

となりE[μμT]=mmT+(βΣ)1であること,トレースには線形性があること,ウィシャート分布の期待値が式(86)で表されること,および期待値とトレースは交換可能であることを利用しました。

期待値とトレースの交換

D次元正方行列Aを考える。

(145)A=(a11aDD)

期待値の線形性とトレースの定義により,期待値とトレースが交換可能であることが分かる。

(146)E[Tr[A]]=E[d=1Dadd](147)=d=1DE[add](148)=Tr[(E[a11]E[aDD])](149)=Tr[E[A]]

式(121)を今回の仮定μN{μk|mk,(βkΣk)1}およびΣkW(Σk|Wk,νk)に適用すると,

(150)Eμk,Σk[(xnμk)TΣk(xnμk)]=νk(xnmk)TWk(xnmk)+βk1D

が得られます。この結果を利用してlnρnkを計算すると,

lnρnk=E[lnπk]+12E[ln|Σk|]D2ln(2π)(151)12Eμk,Σk[(xnμk)TΣk(xnμk)]=ψ(αk)ψ(k=1Kαk)+12i=1Dψ(νk+1i2)+Dln(2)+ln|Wk|(152)12{νk(xnmk)TWk(xnmk)+βk1D}(153)=lnπ~k+lnΣ~k12{νk(xnmk)TWk(xnmk)+βk1D}

が得られます。ただし,ごちゃごちゃした項は以下のように置き換えました。

(154)lnπ~k=E[lnπk](155)=ψ(αk)ψ(k=1Kαk)(156)lnΣ~k=12E[ln|Σk|](157)=i=1Dψ(νk+1i2)+Dln(2)+ln|Wk|

変分ベイズのEステップでは,このρnkを計算してrnkを求めます。そのために,ダイレクトにrnkを更新できるような式を求めておきましょう。式(153)の定義からρnkを求めます。ρnkさえ求まれば,あとは正規化するだけで負担率rnkが計算できます。

(158)ρnk=exp[lnπ~k+12lnΣ~k12{νk(xnmk)TWk(xnmk)+βk1D}](159)=π~kΣ~k12exp{νk2(xnmk)TWk(xnmk)D2βk}

以上で,「3. 更新式適用による変分下限の最大化」における変分Eステップの導出が完了しました。

変分Mステップ

続いて,パラメータに関する最適化を行う変分Mステップの導出を行います。Mステップでは,変分ベイズの更新式を用いて潜在変数以外のパラメータに関して最適化を施していきます。変分ベイズの更新式を利用するためには,π,μ,Σの対数同時分布lnp(π,μ,Σ)を知る必要があります。「1. 同時分布の依存関係確認」の仮定に基づくと,lnp(π,μ,Σ)は以下のように表されます。

(160)lnp(π,μ,Σ)=lnp(X|Z,μ,Σ)+lnp(Z|π)+lnp(π)+lnp(μ|Σ)+lnp(Σ)

したがって,変分ベイズの更新式を用いると,π,μ,Σに関する近似事後分布は以下のように求められます。

(161)lnq(π,μ,Σ)=EZ[lnp(π,μ,Σ)]=k=1Kn=1NE[znk]lnN(xn|μk,Σ1)+EZ[lnp(Z|π)](162)+lnp(π)+k=1Klnp(μk,Σk)+const

以下では,πμ,Σに分けて近似事後分布を導出したいと思います。方針としては,式(162)において,πに関わる部分とμ,Σに関わる部分を別々に抽出します。ここでπμ,Σを別々に考えるのは,「1. 同時分布の依存関係確認」の仮定に基づいています。グラフィカルモデルを見ても分かる通り,μ,Σは依存関係にありますが,πは孤立しています。

多項分布p(Z|π)の共役事前分布としてディリクレ分布p(π)を仮定していますので,πに関する近似事後分布q(π)もまたディリクレ分布の形になります。今回は対数近似事後分布を考えていますので,式(162)のπに関わる部分はディリクレ分布の対数を取った形になるはずです。したがって,ディリクレ分布の対数を取った理想の形と係数比較を行うことで,πに関する近似事後分布を求めることができます。

μ,Σも同様に導出します。多変量正規分布p(X|Z,μ,Σ)の共役事前分布としてガウスウィシャート分布p(μ|Σ)p(Σ)を仮定していますので,近似事後分布q(μ|Σ)は多変量ガウス分布,q(Σ)はウィシャート分布になるはずです。したがって,多変量正規分布・ウィシャート分布の対数を取った理想の形と係数比較を行うことで,μ,Σに関する近似事後分布を求めることができます。

早速,式(162)においてπに関わる部分を抽出します。

(163)lnq(π)=k=1Kn=1Nrnklnπkln{B(α0)k=1Kπkα01}+const(164)=k=1Kn=1Nrnklnπk+k=1K(α01)lnπk+const(165)=k=1K(α01+n=1Nrnk)lnπk+const

ただし,ディリクレ分布のパラメータα0について,Kクラスの初期値の対称性を考えて全ての要素をα0としました。ここで,分かりやすさのため,

(166)Nk=n=1Nrnk=n=1NE[znk]

とおくと,以下のようにlnq(π)をキレイな形に変形することができます。

(167)lnq(π)=k=1K(α01+Nk)lnπk+const

先ほどもお伝えした通り,この形はディリクレ分布の対数をとったものになっているはずです。

(168)lnq(π)=lnDir(π|α)=k=1K(αk1+Nk)lnπk+const

すると,以下の恒等式が成り立ちます。

(169)k=1K(α0+Nk1)lnπk=k=1K(αk1)lnπk

係数比較により,αkに関する更新式が得られます。

(170)αk=α0+Nk

これにて,πに関する最適化は終了です。続いて,μ,Σに関する最適化を行います。式(162)において,μ,Σに関わる部分を抽出します。

lnq(μ,Σ)=k=1KlnN(μ|m0,(β0Σk)1)+k=1KlnW(Σk|W0,ν0)(171)+n=1Nk=1KE[znk]lnN(xn|μk,Σk1)+constk=1K{12ln|β0Σk|12(μkm0)Tβ0Σk(μkm0)+ν0D12ln|Σk|12Tr(W01Σk)(172)+12Nkln|Σk|12n=1Nrnk(xnμk)TΣk(xnμk)}

ここで注意するべきなのは,最後の項で早とちりしてn=1Nrnk=Nkと変形しないことです。なぜなら,nの中身にxnが含まれているため,Nkの定義とは一致しないからです。ここは,引っかかりポイントだと思います。

さて,ここからはグラフィカルモデルに基づいてμΣの依存関係を明確にしましょう。μΣに依存していますから,ベイズの定理よりq(μ,Σ)は理想的には以下のように分解されます。

(173)q(μ,Σ)=q(μ|Σ)q(Σ)

両辺の対数を取ります。

(174)lnq(μ,Σ)=lnq(μ|Σ)+lnq(Σ)

したがって,式(172)においてμに関する項だけを抽出した結果はlnq(μ|Σ)となることが分かります。

lnq(μ|Σ)(175)=12k=1K{(μkm0)Tβ0Σk(μkm0)+Nk(xnμk)TΣk(xnμk)}

先ほどもお伝えした通り,この形は多変量正規分布の対数をとったものになっているはずです。

(176)lnq(μ|Σ)=k=1KlnN(μk|mk,βkΣk)(177)=k=1K{12|Σk|βk2(μkmk)TΣk(μkmk)+const}

πのときと同様に,式(172)と式(177)の恒等式から係数比較により更新式を導出していきます。そこで,一旦指数部のμkTμkの項だけに注目して係数比較を行ってみましょう。式(172)において,μkTμkに関する項は以下です。

(178)μkTβ0Σkμk+NkμkTΣkμk=μkT(β0+Nk)Σkμk

式(177)において,μkTμkに関する項は以下です。

(179)μkTβkΣkμk

したがって,μkの二次の項であるμkTμkに着目すると,以下の恒等式が成り立ちます。

(180)μkT(β0+Nk)Σkμk=μkT(βk)Σkμk

係数比較により,βkに関する更新式が得られます。

(181)βk=β0+Nk

αkに関する更新式(170)と対称的な結果が得られて美しいですね。しかし,まだmkに関する更新式が得られていません。そこで,指数部のμkの一次の項だけに注目して係数比較を行ってみましょう。式(172)において,μkに関する項は以下です。

(182)μkTβ0Σkm0+n=1NrnkμkTΣkxn=μkTΣk(β0m0+Nkxk)

式(177)において,μkに関する項は以下です。

(183)μkTΣkβ0mk

したがって,以下の恒等式が成り立ちます。

(184)βkmk=β0m0+Nkxk

係数比較により,mkに関する更新式が得られます。

(185)mk=1βk(β0m0+Nkxk)

最後に,Σに関する更新式を導出していきたいと思います。式(174)を変形すると,lnq(Σ)は以下のように表されます。

(186)lnq(Σ)=lnq(μ,Σ)lnq(μ|Σ)

式(172)と式(177)から式(186)を計算していきます。

lnq(Σ)=k=1K{12ln|β0Σk|12(μkm0)Tβ0Σk(μkm0)+ν0D12ln|Σk|12Tr(W01Σk)+12Nkln|Σk|12n=1Nrnk(xnμk)TΣk(xnμk)(187)+12|Σk|+βk2(μkmk)TΣk(μkmk)+const}

先ほどもお伝えした通り,この形はウィシャート分布の対数をとったものになっているはずです。

(188)lnq(Σ)=k=1K(νkD12ln|Σk|12Tr[Wk1Σk])

ln|Σk|に注目すると,以下の恒等式が成り立ちます。

(189)ν0D12ln|Σk|=νkD12ln|Σk|

係数比較により,νkに関する更新式が得られます。

(190)νk=ν0+Nk

αkに関する更新式(170),βkに関する更新式(181)と対称的な結果が得られて美しいですね。しかし,まだWkに関する更新式が得られていないため恒等式を立てます。ここでは,トレースに関する以下の3つの性質を利用します。

(191)xTAx=Tr[AxxT](192)Tr[A]+Tr[B]=Tr[A+B](193)Tr[AT]=Tr[A]

上から順番に,トレースと二次形式の関係,トレースの線形性,トレースと転置の関係を表しています。これらの性質を念頭に置くと,式(187)におけるΣkに関する項は以下のように整理されます。

12k=1K{Tr[β0Σk(μkm0)(μkm0)T]+Tr[W01Σk](194)+Tr[n=1NrnkΣk(xnμk)(xnμk)T]Tr[βkΣk(μkmk)(μkmk)T]}=12k=1KTr[β0Σk(μkm0)(μkm0)T+W01Σk(195)+n=1NrnkΣk(xnμk)(xnμk)TβkΣk(μkmk)(μkmk)T]=12k=1KTr[β0Σk(μkm0)(μkm0)T+ΣkW01(196)+n=1NrnkΣk(xnμk)(xnμk)TβkΣk(μkmk)(μkmk)T]

ただし,W0Σkはどちらも対称行列であることを利用しました。一方,式(188)におけるΣkに関する項も,トレースの転置に関する性質を利用すると以下のように整理することができます。

(197)12k=1KTr[Wk1Σk]=12k=1KTr[ΣkWk1]

したがって,得られる恒等式は以下のようになります。

ΣkWk1=β0Σk(μkm0)(μkm0)T+ΣkW01(198)+n=1NrnkΣk(xnμk)(xnμk)TβkΣk(μkmk)(μkmk)T

両辺の左からΣk1を掛けると,Wkに関する以下の更新式が得られます。

Wk1=W01+β0(μkm0)(μkm0)T(199)+n=1Nrnk(xnμk)(xnμk)Tβk(μkmk)(μkmk)T

式(199)をキレイに表すために,以下のように新しい変数を導入します。

(200)xk=1Nkn=1Nrnkxn(201)Sk=1Nkn=1Nrnk(xnxk)(xnxk)T

これらを利用して,式(199)を整理していきます。ポイントとなるのは,今回はlnq(Σ)を考えていますので,μkの項は打ち消されるという点です。そこで,以下ではμkの項を無視して考えていきます。式(199)を整理する際に必要となる計算を予め行っておきます。

n=1NrnkxnxnT=n=1Nrnk(xnxk)(xnxk)T(202)+2n=1Nrnkxnxkn=1NrnkxkxkT(203)=n=1Nrnk(xnxk)(xnxk)T+2NkxkxkTxkxkT(204)=n=1Nrnk(xnxk)(xnxk)T+NkxkxkT(205)=NkSk+NkxkxkT

式(199)を整理しましょう。βkの更新式である式(181)とmkの更新式である式(185)を利用します。先ほどもお伝えしましたが,μkを含む項を無視するのがポイントです。

(206)Wk1=W01+β0m0m0T+n=1NrnkxnxnTβkmkmkT(207)=W01+β0m0m0T+(NkSk+NkxkxkT)βkmkmkT=W01+NkSk+β0m0m0T+NkxkxkT(208)(β0+Nk)1(β0+Nk)2(β0m0+Nkxk)(β0m0T+NkxkT)=W01+NkSk+1β0+Nk{(β0+Nk)(β0m0m0T+NkxkxkT)(209)(β0m0+Nkxk)(β0m0T+NkxkT)}(210)=W01+NkSk+β0Nkβ0+Nk(xkm0)(xkm0)T

βk=β0+Nk を利用すれば,μkの項が消えることを計算しても示すことができます。

以上で,潜在変数と全てのパラメータに関する更新式が得られました。改めて,変分ベイズのGMMへの適用の流れをまとめておきます。

変分ベイズのGMMへの適用(GMM-VB)は,以下の4ステップから構成される。

  1. 初期値設定
  2. 負担率rnkの計算(変分Eステップ)
  3. 各パラメータの更新(変分Mステップ)
  4. 収束判定

まずは各パラメータに初期値α0β0ν0m0W0を与える。次に,以下の変分Eステップと変分Mステップを収束するまで繰り返す。

変分Eステップ

負担率rnkを計算する。

(211)rnk=ρnkj=1Kρnj

ただし,ρnkは以下で計算できる。

(212)lnπ~k=ψ(αk)ψ(k=1Kαk)(213)lnΣ~k=i=1Dψ(νk+1i2)+Dln(2)+ln|Wk|(214)ρnk=π~kΣ~k12exp{D2βkνk2(xnmk)TWk(xnmk)}

変分Mステップ

パラメータを更新する。

(215)αk=α0+Nk(216)βk=β0+Nk(217)νk=ν0+Nk(218)mk=1βk(β0m0+Nkxk)(219)Wk1=W01+NkSk+β0Nkβ0+Nk(xkm0)(xkm0)T

ただし,NkxkSkは変分Eステップで計算した負担率rnkを用いて計算できる。

(220)Nk=n=1Nrnk(221)xk=1Nkn=1Nrnkxn(222)Sk=1Nkn=1Nrnk(xnxk)(xnxk)T

GMM-VBのアニメーション

以上で,変分ベイズのGMMへの適用は完了です。初期値の設定や収束判定は,実装の章で確認します。

ここまで読了された賢明な読者の皆さまはお気づきかもしれませんが,変分ベイズの欠点は更新式の導出が煩雑になることです。今回採り上げた混合ガウス分布は比較的単純なケースなのですが,それでも上で確認したように計算がかなり煩雑になってしまいます。このような背景から,ベイズモデリングの実際の応用ではマルコフ連鎖モンテカルロ法(Markov chain Monte Carlo methods: MCMC)などのサンプリング法が用いられることがあります。

実装

本章では,GMMを題材にEMアルゴリズムに基づくクラスタリングと変分ベイズに基づくクラスタリングの実装をお伝えしていきます。具体的には,10000個の3次元データをGMM-EM(EMアルゴリズムによる混合ガウス分布の推論)とGMM-VB(変分ベイズによる混合ガウス分布の推論)を利用してクラスタリングする実装例をお伝えしていきます。要するに,以下のパラメータを仮定します。

(223)N=10000(224)D=3

実装はGithubで公開しています。

ソースコードのコメントは英語で書いています。これはGithubで公開する際に,外国の方々にも参考にしていただきたいからです。コメント規則であるdocstringはGoogleスタイルを利用しています。

ここからは,以下の形式でメソッド単位で解説を行っていきます。

[メソッド名・その他タイトルなど]

# ソースコード

[ソースコードの解説]

Github上のコードはDockerを用いて動かせるように整備していますが,下記のコードはGoogle Colaboratoryを用いて気軽に試せるようにしています。ぜひコードをコピペしてお試しください。

データの準備

以下では,GMM-EMとGMM-VBを適用するデータを生成していきます。

import sys # 引数の操作
import csv # csvの読み込み
import numpy as np # 数値計算
import matplotlib.pyplot as plt # 可視化
from numpy import linalg as la # 行列計算
from collections import Counter # 頻度カウント
from scipy.special import digamma, logsumexp # 数値計算
from scipy.stats import multivariate_normal # 多次元ガウス分布の確率密度関数の計算
from mpl_toolkits.mplot3d import Axes3D # 可視化

標準的なライブラリをインポートします。

# 各クラスターに属するデータ数
# 今回は全てのクラスタでデータ数が同じと仮定する
# 全体のデータ数は N = N1 + N2 + N3 + N4 となる
# 各クラスタのデータ数
N1 = 4000
N2 = 3000
N3 = 2000
N4 = 1000

# 平均
Mu1 = [5, -5, -5]
Mu2 = [-5, 5, 5]
Mu3 = [-5, -5, -5]
Mu4 = [5, 5, 5]

# 共分散
Sigma1 = [[1, 0, -0.25], [0, 1, 0], [-0.25, 0, 1]]
Sigma2 = [[1, 0, 0], [0, 1, -0.25], [0, -0.25, 1]]
Sigma3 = [[1, 0.25, 0], [0.25, 1, 0], [0, 0, 1]]
Sigma4 = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]

# 乱数を生成
X1 = np.random.multivariate_normal(Mu1, Sigma1, N1)
X2 = np.random.multivariate_normal(Mu2, Sigma2, N2)
X3 = np.random.multivariate_normal(Mu3, Sigma3, N3)
X4 = np.random.multivariate_normal(Mu4, Sigma4, N4)

データを作成します。今回は4つのガウス分布からデータを生成しましょう。クラスタリングとしては非常に簡単な問題設定になっています。なお,以下で載せている画像は当サイトのカラーリストを利用していますが,今回お伝えする実装はtab10colormapを利用します。本質的には重要でない部分ですので,気にしなくても大丈夫です。

# 描画準備
fig = plt.figure(figsize=(4, 4), dpi=300)
ax = Axes3D(fig)
fig.add_axes(ax)

# 当サイトのカスタムカラーリスト
cm = plt.get_cmap("tab10")   

# メモリを除去
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

# 少し回転させて見やすくする
ax.view_init(elev=10, azim=70)

# 描画
ax.plot(X1[:,0], X1[:,1], X1[:,2], "o", ms=0.5, color=cm(0))
ax.plot(X2[:,0], X2[:,1], X2[:,2], "o", ms=0.5, color=cm(1))
ax.plot(X3[:,0], X3[:,1], X3[:,2], "o", ms=0.5, color=cm(2))
ax.plot(X4[:,0], X4[:,1], X4[:,2], "o", ms=0.5, color=cm(3))
plt.show()
クラスタリングを行うデータ

可視化して確認していきましょう。

# 4つのクラスを結合
X = np.concatenate([X1, X2, X3, X4])

# 描画準備
fig = plt.figure(figsize=(4, 4), dpi=300)
ax = Axes3D(fig)
fig.add_axes(ax)

# メモリを除去
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

# 少し回転させて見やすくする
ax.view_init(elev=10, azim=70)

# 描画
ax.plot(X[:,0], X[:,1], X[:,2], "o", ms=0.5, color=cm(0))
plt.show()
クラスラベルが分かっていない状況でクラスタリングを行う

今回は,4つのクラスラベルを排除したものをクラスタリング対象としますので,データを結合しましょう。

# csvでデータを保存
np.savetxt("data.csv", X, delimiter=",")

毎回同じデータに対してクラスタリングを行うため,データをcsvに吐き出しておきます。当サイトのデモで利用しているcsvは以下でダウンロードできます。

# ご自身の環境におけるcsvへのパス
csv_dir = "[path to data.csv]"

# csvを読み取って(N, D)行列を生成
with open(csv_dir) as f:
    reader = csv.reader(f)
    X = [_ for _ in reader]
    for i in range(len(X)):
      for j in range(len(X[i])):
        X[i][j] = float(X[i][j])

# 後のためにnumpy化しておく
X = np.array(X)

当サイトと全く同じデータを利用したい場合は,上記ボタンからcsvをダウンロードしていただき,csvからデータを生成してください。ただし,[path to data.csv] にはご自身の環境におけるcsvファイルへのパスを挿入してください。この後から,EMアルゴリズムと変分ベイズで実装が分岐していきます。

EMアルゴリズム

GMMEMクラスの定義

詳しい実装内容は「EMアルゴリズムをはじめからていねいに」をご参照下さい。

GMMEMの実行

上で定義したGMMEMクラスのインスタンスを生成して,GMM-EMを実行しましょう。以下では,GMM-VBとの比較を行うためにいくつかの条件で実行してみます。

K=4の場合

# モデルをインスタンス化する
model = GMMEM(K=4)
# EMアルゴリズムを実行する
model.execute(X, iter_max=100, thr=0.001)

クラスタ数K4,最大更新回数は100回としました。収束判定における対数尤度増加幅の閾値は0.001としました。

K=4の場合

しっかりとクラスタリングできていますね。ログを確認すると,更新回数は10回だったことが分かります。

Log-likelihood gap: 19.66
Log-likelihood gap: 11.95
Log-likelihood gap: 4.53
Log-likelihood gap: 0.04
Log-likelihood gap: 0.74
Log-likelihood gap: 0.57
Log-likelihood gap: 0.01
Log-likelihood gap: 2.94
Log-likelihood gap: 1.65
Log-likelihood gap: 0.0
EM algorithm has stopped after 10 iteraions.

K=8の場合

# モデルをインスタンス化する
model = GMMEM(K=8)
# EMアルゴリズムを実行する
model.execute(X, iter_max=100, thr=0.001)

クラスタ数K8として,他のパラメータは先ほどと同様の設定で行いました。

K=8の場合

正しくクラスタリングされませんでした。ログを確認すると,更新回数は最大の100回となっていたため,うまく収束できなかったようです。点推定では,4つに分かれているクラスタをK=8の条件下では正しく推定できなかったということです。

Log-likelihood gap: 23.37
Log-likelihood gap: 8.06
Log-likelihood gap: 2.36
Log-likelihood gap: 3.5
Log-likelihood gap: 3.02
~~~~~~~~~~
省略
~~~~~~~~~~
Log-likelihood gap: 0.22
Log-likelihood gap: 0.2
Log-likelihood gap: 0.19
Log-likelihood gap: 0.18
Log-likelihood gap: 0.17
EM algorithm has stopped after 100 iteraions.

データの中に実質的な情報が僅かしか含まれていない性質のことを「スパース性」と呼びます。今回の場合は,仮定したクラス数よりも実際に推定するクラス数の方が少なく,EMアルゴリズムがスパース性を考慮できなかったために,クラスラベルを正しく推定することができなかったと考えられます。変分ベイズでは,後に述べる「関連度自動決定」により,スパース性を担保したクラスタリングが可能になります。

変分ベイズ

GMMVBクラスの定義

class GMMVB():

ここからは,GMMVBクラスを定義していきます。

コンストラクタ

    def __init__(self, K):
        """コンストラクタ

        Args:
            K (int): クラスタ数

        Returns:
            None.

        Note:
            eps (float): オーバーフローとアンダーフローを防ぐための微小量
        """
        self.K = K
        self.eps = np.spacing(1)

クラスタ数Kと微小量epsを定義します。eps0除算を防ぐため等に利用します。

パラメータ初期化メソッド

    def init_params(self, X):
        """パラメータ初期化メソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ

        Returns:
            None.
        """
        # 入力データ X のサイズは (N, D)
        self.N, self.D = X.shape
        # スカラーのパラメータセット
        self.alpha0 = 0.01
        self.beta0 = 1.0
        self.nu0 = float(self.D)
        # 平均は標準ガウス分布から生成
        self.m0 = np.random.randn(self.D)
        # 分散共分散行列は単位行列
        self.W0 = np.eye(self.D)
        # 負担率は標準正規分布から生成するがEステップですぐ更新するので初期値自体には意味がない
        self.r = np.random.randn(self.N, self.K)
        # 更新対象のパラメータを初期化
        self.alpha = np.ones(self.K) * self.alpha0
        self.beta = np.ones(self.K) * self.beta0
        self.nu = np.ones(self.K) * self.nu0
        self.m = np.random.randn(self.K, self.D)
        self.W = np.tile(self.W0[None,:,:], (self.K, 1, 1))
スクロールできます
変数サイズ意味
Nintデータ数
Dintデータの次元
alpha0floatディリクレ分布のパラメータ初期値
beta0floatガウス分布の分散共分散行列の係数初期値
nu0floatウィシャート分布のパラメータ初期値
m0(K, D)ガウス分布の平均
W0(K, D, D)ガウス分布の分散共分散行列
r(N, K)負担率
alpha(K)ディリクレ分布のパラメータ
beta(K)ガウス分布の分散共分散行列の係数
nu(K)ウィシャート分布のパラメータ
m(N, D)ガウス分布の平均
W(K, D, D)ウィシャート分布のパラメータ
フィールドで宣言している変数

パラメータの初期化を行います。

ディリクレ分布パラメータα0の初期値は0.01としました。α0πkの事前分布のパラメータですので,素朴に考えると1/Kが適切のように思えますが,変分ベイズではスパース性を担保した推定が行えることから,初期値として小さな値を採用しています。

ガウス分布のパラメータβ0の初期値は1.0としました。β0Σkのスケールを表していますので,初期値としては等倍である1.0を選択しました。

ウィシャート分布のパラメータν0の初期値は次元数Dとしました。上では指摘していませんが,ウィシャート分布におけるνD1よりも大きくなければいけません。ν0として1.00を与えてしまうと,推定が不安定な挙動を示してしまうため注意が必要です。

GMMの確率密度関数計算メソッド

    def gmm_pdf(self, X):
        """N個のD次元データに対してGMMの確率密度関数を計算するメソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ

        Returns:
            Probability density function (numpy ndarray): 各クラスタにおけるN個のデータに関して計算するため出力サイズは (N, K) となる
        """
        pi = self.alpha / (np.sum(self.alpha, keepdims=True) + np.spacing(1)) # (K)
        return np.array([pi[k] * multivariate_normal.pdf(X, mean=self.m[k], cov=la.pinv(self.nu[:,None,None] * self.W)[k]) for k in range(self.K)]).T # (N, K)

GMMの確率密度関数を計算します。EMアルゴリズムでは混合ガウス分布の混合率πk・平均μk・分散共分散行列Σkは更新対象のパラメータでしたが,変分ベイズではπkμkΣkには事前分布を設定しているために,それら自体は更新対象のパラメータではありません。言い換えれば,現時点ではGMMの確率密度関数を計算するためのπkμkΣkが手に入っていないということです。

このような背景から,πkμkΣkはそれぞれの事前分布から代表値を抽出してあげる必要があります。代表値として何を用いるのかは自明ではありませんが,ここでは単純に期待値を採用します。式(70)と式(79)より,πの期待値はディリクレ分布の期待値になります。

(225)E[πk]=αkj=1Kαj

同様に,式(71)と式(84)より,μの期待値はガウス分布の期待値になります。

(226)E[μk]=m0

同様に,式(72)と式(86)より,Σの期待値はウィシャート分布の期待値になります。

(227)E[Σk]=νkWk

これらの代表値を用いて,GMMの確率密度関数を計算します。計算自体はEMアルゴリズムと同様ですので,詳細は割愛します。

変分Eステップメソッド

    def e_step(self, X):
        """変分Eステップを実行するメソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ

        Returns:
            None.

        Note:
            以下のフィールドが更新される
                self.r (numpy ndarray): (N, K)サイズの負担率
        """
        # rhoを求めるために必要な要素の計算
        log_pi_tilde = (digamma(self.alpha) - digamma(self.alpha.sum()))[None,:] # (1, K)
        log_sigma_tilde = (np.sum([digamma((self.nu + 1 - i) / 2) for i in range(self.D)]) + (self.D * np.log(2) + (np.log(la.det(self.W) + np.spacing(1)))))[None, :] # (1, K)
        nu_tile = np.tile(self.nu[None,:], (self.N, 1)) # (N, K)
        res_error = np.tile(X[:,None,None,:], (1, self.K, 1, 1)) - np.tile(self.m[None,:,None,:], (self.N, 1, 1, 1)) # (N, K, 1, D)
        quadratic = nu_tile * ((res_error @ np.tile(self.W[None,:,:,:], (self.N, 1, 1, 1))) @ res_error.transpose(0,1,3,2))[:,:,0,0] # (N, K)
        # 対数領域でrhoを計算
        log_rho = log_pi_tilde + (0.5 * log_sigma_tilde) - (0.5 * self.D / (self.beta + np.spacing(1)))[None,:] - (0.5 * quadratic) # (N, K)
        # logsumexp関数を利用して対数領域で負担率を計算
        log_r = log_rho - logsumexp(log_rho, axis=1, keepdims=True) # (N, K)
        # 対数領域から元に戻す
        r = np.exp(log_r) # (N, K)
        # np.expでオーバーフローを起こしている可能性があるためnanを置換しておく
        r[np.isnan(r)] = 1.0 / (self.K) # (N, K)
        self.r = r # (N, K)

変分Eステップを実行します。具体的には,以下を計算した後に,

(228)lnπ~k=ψ(αk)ψ(k=1Kαk)(229)lnΣ~k=i=1Dψ(νk+1i2)+Dln(2)+ln|Wk|(230)ρnk=π~kΣ~k12exp{D2βkνk2(xnmk)TWk(xnmk)}

その結果を用いて以下を計算します。

(231)rnk=ρnkj=1Kρnj

EMアルゴリズム同様,オーバーフローを防ぐために対数領域で計算を行います。早速,負担率の対数を取ってみましょう。

(232)lnrnk=lnρnklnj=1Kρnj

EMアルゴリズム同様,lnρnkの計算は行列演算を用いて高速化を図ります。

rhoの行列演算

(1,K)(N,K)同士の和算・減算がありますが,ここでは(1,K)側を自動で(N,K)側に揃えてくれるブロードキャストというnumpyの機能を利用します。

改めて式(230)をみると,第二項目に対数領域ではないρnkが含まれています。ρnkには指数演算が含まれていますので,ρnk自体の計算は対数領域で行ったうえで,一番最後に指数演算を施して元に戻してあげた方がベターです。オーバーフローのリスクはなるべく一か所に集中させてあげた方が良いという思想です。

そこで,以下の変形を無理やり行うことで,第二項目でもlnρnkを利用して計算を行うことが可能になります。

(233)lnrnk=lnρnklnj=1Kexp(lnρnj)

ここで,第二項目に現れたlnexpscipy.speciallogsumexpというモジュールを利用することができます。

logsumexpは対数領域で計算を行う場合によく出現する項です。今回のケースのように,オーバーフローを防ぐための施策として頻繁に用いられる汎用的な手法ですので,ここでおさえておくとよいでしょう。

Mステップメソッド

    def m_step(self, X):
        """変分Mステップを実行するメソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ

        Returns:
            None.

        Note:
            以下のフィールドが更新される
                self.alpha (numpy ndarray): (K) サイズのディリクレ分布のパラメータ
                self.beta (numpy ndarray): (K) ガウス分布の分散共分散行列の係数
                self.nu (numpy ndarray): (K) サイズのウィシャート分布のパラメータ
                self.m (numpy ndarray): (K, D) サイズの混合ガウス分布の平均
                self.W (numpy ndarray): (K, D, D) サイズのウィシャート分布のパラメータ
        """
        # 各パラメータを求めるために必要な要素の計算
        N_k = np.sum(self.r, 0) # (K)
        r_tile = np.tile(self.r[:,:,None], (1, 1, self.D)).transpose(1, 2, 0) # (K, D, N)
        x_bar = np.sum((r_tile * np.tile(X[None,:,:], (self.K, 1, 1)).transpose(0,2,1)), 2) / (N_k[:,None] + np.spacing(1)) # (K, D)
        res_error = np.tile(X[None,:,:], (self.K, 1, 1)).transpose(0,2,1) - np.tile(x_bar[:,:,None], (1, 1, self.N)) # (K, D, N)
        S = ((r_tile * res_error) @ res_error.transpose(0,2,1)) / (N_k[:,None,None] + np.spacing(1)) # (K, D, D)
        res_error_bar = x_bar - np.tile(self.m0[None,:], (self.K, 1)) # (K, D)
        # 各パラメータを更新
        self.alpha = self.alpha0 + N_k #(K)
        self.beta = self.beta0 + N_k #(K)
        self.nu = self.nu0 + N_k #(K)
        self.m = (np.tile((self.beta0 * self.m0)[None,:], (self.K, 1)) + (N_k[:, None] * x_bar)) / (self.beta[:,None] + np.spacing(1)) # (K, D)
        W_inv = la.pinv(self.W0) + (N_k[:,None,None] * S) + (((self.beta0 * N_k)[:,None,None] * res_error_bar[:,:,None] @ res_error_bar[:,None,:]) / (self.beta0 + N_k)[:,None,None] + np.spacing(1)) # (K, D, D)
        self.W = la.pinv(W_inv) # (K, D, D)

変分Mステップを実行します。変分EステップはEMアルゴリズムと同様に負担率を求めるだけでしたので,負担率の計算方法を除いて特に変わりはありませんでした。一方で,変分Mステップでは更新対象のパラメータがごっそり変わるため実装もかなり変わってきます。具体的には,以下を計算した後に,

(234)Nk=n=1Nrnk(235)xk=1Nkn=1Nrnkxn(236)Sk=1Nkn=1Nrnk(xnxk)(xnxk)T

その結果を用いて以下を計算します。

(237)αk=α0+Nk(238)βk=β0+Nk(239)νk=ν0+Nk(240)mk=1βk(β0m0+Nkxk)(241)Wk1=W01+NkSk+β0Nkβ0+Nk(xkm0)(xkm0)T

EMアルゴリズム同様,Skの計算は行列演算を用いて高速化を図ります。

Sの行列演算

同様に,Wk1の計算も行列演算を用いて高速化を図ります。

W_invの行列演算

注意が必要なのは,フィールドではself.Wを逆行列として定めていないという点です。式(241)は逆行列W1に対する更新式ですので,self.Wをアップデートする際は逆行列にしてあげる必要があります。

逆行列の計算ですが,la.invを利用すると行列が正則でない場合に逆行列が計算できずにエラーが出てしまいます。そこで,正則でない行列に対しても逆行列(に似た性質の行列)を計算できるla.pinvを利用します。また,除算を行う際は0除算を防ぐために微小量epsを分母に加算していることに注意してください。

可視化メソッド

    def visualize(self, X):
        """可視化を実行するメソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ

        Returns:
            None.
        
        Note:
            このメソッドでは plt.show が実行されるが plt.close() は実行されない
        """
        # クラスタリングを実行
        labels = np.argmax(self.r, 1) # (N)
        # 利用するカラーを極力揃えるためクラスタを出現頻度の降順に並び替える
        label_frequency_desc = [l[0] for l in Counter(labels).most_common()]
        # tab10 カラーマップを利用
        cm = plt.get_cmap("tab10")   
        # 描画準備
        fig = plt.figure(figsize=(4, 4), dpi=300)
        ax = Axes3D(fig)
        fig.add_axes(ax)
        # メモリを除去
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        # 少し回転させて見やすくする
        ax.view_init(elev=10, azim=70)
        # 各クラスタごとに可視化を実行する
        for k in range(len(label_frequency_desc)):
            cluster_indexes = np.where(labels==label_frequency_desc[k])[0]
            ax.plot(X[cluster_indexes, 0], X[cluster_indexes, 1], X[cluster_indexes, 2], "o", ms=0.5, color=cm(k))
        plt.show()

VBが収束した後に,各クラスごとに色分けした可視化を行い結果を確認します。EMアルゴリズムと全く同じメソッドです。

実行メソッド

    def execute(self, X, iter_max, thr):
        """VBを実行するメソッド

        Args:
            X (numpy ndarray): (N, D)サイズの入力データ
            iter_max (int): 最大更新回数
            thr (float): 更新停止の閾値 (対数尤度の増加幅)

        Returns:
            None.
        """
        # パラメータ初期化
        self.init_params(X)
        # 各イテレーションの対数尤度を記録するためのリスト
        log_likelihood_list =[]
        # 対数尤度の初期値を計算
        log_likelihood_list.append(np.mean(np.log(np.sum(self.gmm_pdf(X), 1) + self.eps)))
        # 更新開始
        for i in range(iter_max):
            # Eステップの実行
            self.e_step(X)
            # Mステップの実行
            self.m_step(X)
            # 今回のイテレーションの対数尤度を記録する
            log_likelihood_list.append(np.mean(np.log(np.sum(self.gmm_pdf(X), 1) + self.eps)))
            # 前回の対数尤度からの増加幅を出力する
            print("Log-likelihood gap: " + str(round(np.abs(log_likelihood_list[i] - log_likelihood_list[i+1]), 2)))
            # もし収束条件を満たした場合,もしくは最大更新回数に到達した場合は更新停止して可視化を行う
            if (np.abs(log_likelihood_list[i] - log_likelihood_list[i+1]) < thr) or (i == iter_max - 1):
                print(f"VB has stopped after {i + 1} iteraions.")
                self.visualize(X)
                break

今まで定義してきたメソッドを駆使して,GMM-VBを実行します。初期化,変分Eステップ,変分Mステップ,収束判定と繰り返すことで,パラメータを更新していきます。収束判定では,「閾値」もしくは「最大更新回数」のいずれかが引っかかった場合に更新をストップするようにしています。EMアルゴリズム同様,閾値判定に用いる対数尤度は,全てのデータ点の和ではなく平均を利用しています。結局,EMアルゴリズムと全く同じメソッドです。

EMアルゴリズムの目的関数は対数尤度でしたので,収束条件に対数尤度を設定するのは合理的です。一方で,変分ベイズの目的関数はKLダイバージェンスもしくは下限でしたので,収束条件に対数尤度を設定するのは一見合理的ではありません。しかし,対数尤度が「パラメータが与えられたときの観測データの尤もらしさ」を表すことを踏まえれば,変分ベイズでも収束条件に対数尤度を設定するのは十分合理的です。

GMMVBの実行

上で定義したGMMVBクラスのインスタンスを生成して,GMM-VBを実行しましょう。以下では,GMM-VBとの比較を行うためにいくつかの条件で実行してみます。

K=4の場合

# モデルをインスタンス化する
model = GMMVB(K=4)
# VBを実行する
model.execute(X, iter_max=100, thr=0.001)

クラスタ数K4,最大更新回数は100回としました。収束判定における対数尤度の増加幅の閾値は0.001としました。

K=4の場合

しっかりとクラスタリングできていますね。ログを確認すると,更新回数は10回だったことが分かります。

Log-likelihood gap: 28.41
Log-likelihood gap: 0.08
Log-likelihood gap: 0.05
Log-likelihood gap: 0.01
Log-likelihood gap: 0.01
Log-likelihood gap: 0.03
Log-likelihood gap: 0.07
Log-likelihood gap: 0.14
Log-likelihood gap: 0.07
Log-likelihood gap: 0.0
VB has stopped after 10 iteraions.

K=8の場合

# モデルをインスタンス化する
model = GMMVB(K=8)
# VBを実行する
model.execute(X, iter_max=100, thr=0.001)

クラスタ数K8として,他のパラメータは先ほどと同様の設定で行いました。

K=8の場合

しっかりとクラスタリングできていますね。ログを確認すると,更新回数は6回だったことが分かります。

Log-likelihood gap: 12.22
Log-likelihood gap: 2.33
Log-likelihood gap: 1.33
Log-likelihood gap: 0.08
Log-likelihood gap: 0.01
Log-likelihood gap: 0.0
VB has stopped after 6 iteraions.

EMアルゴリズムのときとは異なり,クラスタ数が過剰に設定されている場合でも変分ベイズでは正しくクラスタリングが行われていることが分かりました。このように,変分ベイズがスパースなクラスタ分類を行うはたらきを「関連度自動決定」と呼びます。クラスタの個数を大きめ(K=8など)に設定すれば,勝手に適切なクラスタ数にフィットしてくれるというのです。

周辺尤度最大化するときにデータの傾向にそぐわない基底関数は周辺尤度を減少させる方向に寄与するため,勝手にクラスタ数をフィットしてくれるという訳です。詳しくは[1]の7.2.2章をご覧ください。

おわりに

本稿では,変分ベイズの目的とその実現方法を,点推定との比較を通じてボトムアップ的にお伝えしてきました。冒頭でも述べたように,変分ベイズは機械学習を学ぶ上で高く反り立つ壁です。多くの人が挫折を経験したことは想像に難くありません。本稿が,一人でも多くの人にとって変分ベイズを学ぶ助けとなれることを心から願っております。

参考文献

[1] Christopher M. Bishop, "Pattern Recognition and Machine Learning."

付録

本章では,EMアルゴリズムと変分ベイズのEステップとMステップの対応を考えます。

Zを構成する確率変数をZ0,,ZKと表すことにします。EMアルゴリズムでは潜在変数とパラメータは区別して,変分ベイズでは潜在変数にパラメータを含めるのでした。したがって,Z0,,ZKの中に,EMアルゴリズムにおけるθが対応していることに注意してください。本稿では,潜在変数とそれ以外を明示的に区別するため,Z0を潜在変数,それ以外のZ1,,ZKをパラメータと設定することにします。

GMMを例にした潜在変数の対応関係

変分ベイズは,言ってしまえば式(15)を計算するだけです。式(15)に基づけば,あるZkに対する更新を行うときは,Zk以外の値を固定します。そのため,変分ベイズはEMアルゴリズムのような交互更新を行う必要があります。K+1個のZに対して交互更新を行うため,理論上はK+1ステップ必要になります。

EMアルゴリズムと変分ベイズを比較する文脈では,変分Eステップと変分Mステップという用語が登場します。例えば,変分ベイズを混合ガウス分布に適用する場合,一般的には負担率を計算する変分Eステップとパラメータを更新する変分Mステップに分けられます。

変分ベイズにおいては,Mステップは存在しないというのが私の考えです。EMアルゴリズムにおけるMステップは,対数尤度を最大化するパラメータを点推定するステップでしたが,変分ベイズではパラメータを点推定するフェーズは存在しないからです。

多くの書籍やWeb上の資料では変分Eステップ・変分Mステップという用語が利用されているため,本稿では変分ベイズとEMアルゴリズムを対応付けたいと思います。いま,K種類のパラメータθと潜在変数Zが存在するとします。すると,EMアルゴリズムは,K種類のパラメータを固定して潜在変数Zに関する最適化を行うEステップと,潜在変数Zを固定してK種類のパラメータを点推定するMステップで構成されます。このような立場に立つと,EMアルゴリズムはEM...Mアルゴリズムと捉えることができます(MはK回連続しています)。

同じ立場に立てば,変分ベイズはk+1種類の確率変数Zがある状況ではEE...Eアルゴリズムと捉えることができます(EはK+1回連続しています)。なぜなら,変分ベイズが式(15)を計算するだけだからです。EMアルゴリズムにおけるK回のMステップが,全てEステップに置き換えられていますね。言ってしまえば,自分以外の期待値を取る操作を収束するまで繰り返すのが変分ベイズです。

EM...MアルゴリズムとEE...Eアルゴリズム

特に「EM」アルゴリズムと比較する場合は,潜在変数Zとそれ以外のパラメータθを別々に考えて,潜在変数以外の確率変数に関する式(15)を計算してZに関するKLを最小化するEZステップと,θ以外の確率変数に関する式(15)を計算してθに関するKLを最小化するEθステップに無理矢理二分割します。

VBを無理矢理二分割

このEZステップが巷では変分Eステップと呼ばれており,Eθステップが変分Mステップと呼ばれているようです。しかし,私はこの呼び方が変分ベイズに対して誤解を生んでいる気がしてなりません。EMアルゴリズムとは違って,変分Mステップでは上限を最大化している訳ではないからです。あくまでも,両手法のMステップにおいては,潜在変数以外の確率変数(EMアルゴリズムで言えばパラメータという値)を最適化しているという対応関係しかありません。

変分Eステップと変分Mステップ

上図においては,先ほど設定したようにZ0は潜在変数を表しています。Z0以外のパラメータは,先ほど説明した平均場近似の仮定をおくことで分離されているものとしています。変分Eステップでは,潜在変数Z0に関して最適なパラメータ(負担率と呼ぶ)を求めておきます。

この時点では,実際にZ0にとって最適なパラメータを探しただけであって,実際に更新はしていません。Z0は他のパラメータとは異なり潜在変数です。パラメータの更新を行うのは変分Mステップであり,Z0以外の因子についての最適化を行うことで変分下限を最大化します。このとき,変分Eステップで求めておいた負担率を利用することになるため,EMアルゴリズムと同様の交互更新になります。

しかし,本質的には変分ベイズでは潜在変数もパラメータも確率変数とみなしていますので,潜在変数とパラメータの立場は対等なはずです。変部ベイズにおいて「潜在変数とそれ以外」を区別する必要性が,私はあまりないように感じています。逆に言えば,両者を同一視した概念が変分ベイズだと思っています。

多くの書籍やWeb上の資料では,負担率という言葉だけが一人歩きしているように思えます。負担率の本質はEステップで計算される値であり,Mステップのパラメータ更新で利用される値であるということです。

シェアはこちらからお願いします!

コメント

コメント一覧 (39件)

  • 初めまして。高橋と申します。zuka様のWebサイトはPRMLの演習問題を解くときに大変お世話になりました。
    現在変分法を勉強しているとこだったのでとても勉強になっております。

    1点ご質問なのですが、対数尤度と下限の差を取りKLダイバージェンスを導出するとこの式変形についてです。書き忘れ、消し忘れだとは思うのですが、
    (21)から(22)にかけてlogP(X|Θ)は相殺されて消えると思うのですがいかがでしょうか。また(22)の式全体の符号はマイナスだとおもいました。

    突然の連絡大変失礼いたしました。これからも応援しております!

    • 高橋様

      ご連絡誠にありがとうございます!
      ご丁寧にご指摘いただき本当に助かります。
      本文を修正致しました。

      前のサイトからの読者様ということで,非常に嬉しく思います。
      これからも分かりやすく正確な発信を心掛けていますので,何卒よろしくお願い致します!

  • 素晴らしい記事をありがとうございます!
    EMアルゴリズム、変分ベイズ、transformerとよくわからない概念が出てくるたびにzuka様のサイトがとても参考になるので本当に感謝しています。

    ささいな点ですが、式番号(49)から(50)のところで符号が反転しているので、(50)から(57)のarg minはarg maxの事なのではないかと思いました。

    • 田中様

      ご指摘誠にありがとうございます!
      非常に助かります。

      式(49)-式(57)にかけて,argminargmaxに修正しました。
      いつも記事をご覧いただき嬉しい限りです。
      他にも疑問点や分かりにくい箇所等ございましたら,お気軽にお申し付けください!

  • 非常に細かいですが、「平均近似場」ではなく「平均場近似」ではないでしょうか?

    • mogmog 様

      ご指摘ありがとうございます!
      お恥ずかしい限りです。修正致しました。

  • はじめまして
    詳しい説明ありがとうございます。
    PRMLの演習解説はブックマークさせてもらっています。
    質問なんですが、式(31)~(32)でのInCとCの意味が理解できていません。
    定数だろうとは思うのですが、ここで引いたり分母を割っているのが...
    まだまだ、勉強不足です。
    教えていただけるとうれしいです。

    • sumeragi さま

      ご質問ありがとうございます。前のブログからの読者さまということで,非常に嬉しく思います。

      さて,式(31)~(32)でのInCとCの意味に関してですが,非常によい質問だと思います。本文中で定数を足し引きしたのは,式(15)の形を出現させるためです。式(15)は平均場近似を用いた変分推論の解として知られていますから,いわば天下り的に定数を足し引きしていると解釈してもらえればと思います。

      それでは,この定数が何を意味しているのかというと,分布の正規化定数を表しています。これはGMMへの適用が理解の助けになるでしょう。式(108)は式(15)におけるconstを無視して求められる解ですが,このままでは確率変数に関する総和が1とならず,qが確率分布の定義を満たしません。そこで,一旦式(15)におけるconstを無視して求められる解を計算してしまい,正規化定数は別途求めるというアプローチを取ります。このとき求められた正規化定数が,式(15)におけるconstに相当します。

      以上から,式(31)~(32)で定数を足し引きしているのは,正規化定数までをも含んだ形で変分推論の解を示すためだといえます。

  • いつもお世話になっております。大変分かりやすい記事ありがとうございます。
    さて、細かい話にはなるのですが、式(133)に関して、右辺第一項の、pi_kの乗数ですが、alpha_0ではなく、alpha_0kではないでしょうか?つまり、alpha_0はベクトルですので、その第k成分を乗ずるのではないでしょうか?
    これに伴って、式(134)、(135)も、シグマの外ではなく中にalpha_0が入ってくるものと思います。
    同様に式(135)のN_kに関しても、kのシグマの中に入れるべきではないでしょうか?

    これらが正しければ、式(138)から(141)までも少し修正すべきだと思われます。
    もし間違っていたら申し訳ありませんが、確認していただけますでしょうか?

    • Bayes様

      ご質問誠にありがとうございます。

      >式(133)に関して、右辺第一項の、pi_kの乗数ですが、alpha_0ではなく、alpha_0kではないでしょうか?
      仮定の記述が抜けておりました。Kクラスの初期値の対称性より,α0の全ての要素をα0としました。

      >同様に式(135)のN_kに関しても、kのシグマの中に入れるべきではないでしょうか?
      おっしゃる通りです。ディリクレ分布の恒等式まわりの記述に関して,全面的に修正致しました。

      以上,お手隙の際にご確認いただけますでしょうか。ご指摘非常に助かります!

  • 大変分かりやすい記事ありがとうございます.

    一点だけ教えていただけないでしょうか?
    (104)式では尤度の部分に対数は現れないのでしょうか?
    また(107)ではz_nkがlogρにかけられているのに対して,(105)式ではlogπのみにかけられているのも気になります.

    確認いただけますでしょうか?
    EMアルゴリズムから読ませていただいていて,大変勉強になっております.ありがとうございます。

    • はり様

      ご質問と温かいお言葉ありがとうございます。

      >(104)式では尤度の部分に対数は現れないのでしょうか?
      すみません,誤植です。修正致しました。

      >(107)ではz_nkがlogρにかけられているのに対して,(105)式ではlogπのみにかけられているのも気になります.
      恐れ入りますが,式番号がズレてしまったかもしれません。お手数ですが,今一度質問をご確認いただけますでしょうか。

  • 式103について質問です。
    式36などの変分ベイズの更新式を見ると、qi≠jに対する期待値が取れられていますが、式103では、qi≠jに対する期待値が取られていないように見えます。
    こちらはどのようなことが意味されているのでしょうか?

    • 「平均場近似のグループ分け」にて、Zπ,μ,Σがそれぞれ別のグループに属すると仮定しています。変分ベイズの更新式では「自分以外全ての潜在変数・パラメータで仮定した確率モデルの期待値を取ると近似事後分布の形が得られ」ますので,Zの近似事後分布を求めるためにはπ,μ,Σに関する期待値を考えればよい,となります。

  • 度々申し訳ありません。
    式(103)-(104)の式変形について質問です。

    式(67)では、p(X,Z,π,μ,Σ)=p(X|Z,μ,Σ)p(Z|π)p(π)p(μ|Σ)p(Σ)となっています。
    しかし、式(103)-(104)のp(X,Z,π,μ,Σ)の式変形の際に、p(Z|π)、p(X|Z,μ,Σ)以外はなぜ、消えてしまっているのでしょうか?

    • π,μ,Σに関する期待値を考えているからです。例えば,Xに関する期待値E[X]は定数になりますよね。それと同じで,πに関するp(π)の期待値,μに関するp(μ|Σ)の期待値,Σに関するp(Σ)の期待値は定数になります。なお,期待値の線形性は自明としています。

  • 上記2点、理解することができました。
    ご回答いただき、ありがとうございます。

    追加で質問させていただきたいのですが、定数1/Aを、式(109)のように定めている理由としては、「rnkを式(119)のような形で表したいから」、でしょうか?

    • Yu様

      ご質問ありがとうございます。

      >定数1/Aを、式(109)のように定めている理由としては、「rnkを式(119)のような形で表したいから」、でしょうか?
      いえ。確率の定義からq(Z)Zに関する総和は1となることを利用しています。すなわち,下記関係式を利用しています。
      (242)Zq(Z)=AZn=1Nk=1Kρnkznk=1

  • 何度も何度も申し訳ありません、、、
    式(121)の右辺の第四項の式変形の理解に苦戦しております。
    この式変形を理解するために、どのような情報を参照したらよろしいでしょうか?

    • Yu様

      ご質問ありがとうございます。

      >この式変形を理解するために、どのような情報を参照したらよろしいでしょうか?
      すみません,本文に誤植がありましたので修正致しました。併せて導出も含めて補足も書いてみました。

    • 追記、及び修正ありがとうございます。
      現在、(121)の証明を行おうとしておりますが、(3)のようになってしまうため、何点かお伺いしたいことがあります、、、

      1. a~N(m, Σ)の認識でよろしいでしょうか、、、?
      2. Aはどのような行列でしょうか、、、?
      3. 以下の式(3)のように、(121)式の右辺第二項のTr(AΣ)のようにすっきりした式になりませんでした、、、どのようにすればTr(AΣ)が導出できるのでしょうか、、、?

      お手数ですが、ご回答いただけますと幸いです。
      よろしくお願いいたします。

      === (121)の証明 ===
      E[(x-a)T A (x-a)] を以下のように、変形しております。
      ※EはEa,Aのことです、、、

      E[(x-a)T A (x-a)] = E[xT A x - aT A x - xT A a + aT A a + xT x + aT a]
      = xT E[A] x - E[aT] E[A] x - xT E[A] E[a]+ E[aT A a] + xTx + E[aT a]
      ここで、a~N(m, Σ)だとすると、
      = xT E[A] x - mT E[A] x - xT E[A] m + xTx + mT m + E[aT A a]
      = (x-m)T E[A] (x-m) + E[aT A a] (1)

      E[aT A a]について、aT A a = Tr(aT A a)=Tr(A a aT)となるから、
      E[aT A a] = E[Tr(A a aT)] = Tr(E[A a aT])

      ここで、Aは定数行列?なら、E[aT A a] =Tr(A E[a aT]) (2)

      aの共分散行列がΣの場合、
      Σ=E[(a-m)(a-m)T]
      =E[a aT - a mT - m aT + mmT]
      よって、E[a aT] = Σ - E[- a mT - m aT + mmT]
      = Σ + E[a mT + m aT - mmT]
      これを、(2)に代入すると、E[aT A a] = Tr{A (Σ + E[a mT + m aT - mmT])}
      これを、 (1)に代入すると、
      E[(x-a)T A (x-a)] = (x-m)T E[A] (x-m) + Tr{A (Σ + E[a mT + m aT - mmT])} (3)

    • Yu様

      ご質問ありがとうございます。すみません,誤植がありました。丁寧めに解説を追記してみましたので,ご確認お願い致します。

      念の為ご質問にお答えします。
      >1. a~N(m, Σ)の認識でよろしいでしょうか、、、?
      こちらノーテーションを再定義致しました。

      >2. Aはどのような行列でしょうか、、、?
      こちらもノーテーションを再定義致しました。

      >3. 以下の式(3)のように、(121)式の右辺第二項の...
      こちらも再度本文をご確認いただけますでしょうか。

    • ご丁寧に解説いただき、ありがとうございます、、、
      式(128)-(129)において、Eμ[μμT]が、mmT+(βΣ)^-1になっているのですが、なぜ、(βΣ)^-1も現れるのでしょうか、、、?

    • ご質問ありがとうございます。お手数ですが、直後にある式(141)以降の説明をご覧いただけますでしょうか。

  • 再び失礼します。
    このサイトでは分散共分散行列を精度行列として定めてると思いますが、実装の部分でgmm_pdfでは多変量正規分布の引数のcovが精度行列のまま入力してます。これはどういう意図なんでしょうか?

    • はり様

      ご質問ありがとうございます。

      >実装の部分でgmm_pdfでは多変量正規分布の引数のcovが精度行列のまま入力してます。
      いえ、分散共分散行列として入力しております。実際、m_stepにおいて更新式よりW_invを計算し、その逆行列をself.Wに代入しております。ご確認お願い致します。

    • zuka様

      返信ありがとうございます.
      大変恐縮ではございますが,私の認識がどこからずれてしまっているか一つ一つ確認して教えていただけないでしょうか?お手数おかけして,申し訳ございません.

      1. self.Wはウィシャート分布のパラメータである.
      こちらは実装部分のm_stepの説明のところに明記されてあるかと思います..
      2. 分散共分散行列の逆行列に関するガウス分布の共役事前分布はウィシャート分布である.
      こちらは"1. 同時分布の依存関係確認"の箇所に記述があると思います.
      3. ウィシャート分布の期待値(self.nu*self.W)は精度行列が従う分布の期待値である.
      1. と2. から読み取りました.

      zuka様,重ね重ね申し訳ありませんがよろしくお願い致します.

    • はり様

      大変失礼しました。ご指摘の通りです。的確かつ丁寧なご指摘ありがとうございます。ソースコードを修正致しました。こちらで認識合いますでしょうか。なお、ついでにdockerを用いてVBを実行できるように改修しました。お手隙にご確認お願い致します。

      https://github.com/beginaid/GMM-EM-VB

  • 非常に詳細かつ緻密な解説をいただき、感謝いたします。その上で質問です。ただし、VBにおける潜在変数とEMにおける潜在変数をZVBZEMと区別させていただきます。

    平均場近似は、q(ZVB)は独立な分布qi(Zi)の積で表されるという仮定だと認識しております。このノーテーションをもとに考えれば、GMMでは潜在変数をZ1={ZEM},Z2={π,μ,Σ}と分割しておられるのだと思います。しかし、p(ZEM)が定義されず、代わりにp(ZEM|π)が定義されていることからも、ZEMπは独立しているとは言えないように感じます。なぜZ1={ZEM,π},Z2={μ,Σ}としないのでしょうか?

    • うさぎさま

      ご質問ありがとうございます。

      >平均場近似は、q(ZVB)は独立な分布qi(Zi)の積で表されるという仮定だと認識しております。
      はい。自分の理解と合います。

      >GMMでは潜在変数をZ1={ZEM},Z2={π,μ,Σ}と分割しておられる
      こちらは自分の理解と合いません。
      お手数ですが、記事最下部の「付録」をご確認いただけますでしょうか。
      付録をご覧いただいた上でご質問がある場合は、再度コメントいただけますと幸いです。

    • ご返信ありがとうございます。

      > 記事最下部の「付録」をご確認いただけますでしょうか

      「付録」でおっしゃっていることを要約すると、

      - VBにおける潜在変数Zは、EMにおける潜在変数をZ0として、またモデルパラメタθ1,θ2,,θKZ1,Z2,,ZKとして含んでいる
      - VBにおける最適化はZの要素を平均場近似によって排反かつ独立なグループに分割し、各グループにおいて他のグループの値を固定した時の期待値を計算し続ける

      のように書けると思いますが、まず認識として齟齬がありますでしょうか。その上で、式(16)と式(90)の比較から、GMMではこのグループについてZ0=ZEM,Z1={π,μ,Σ}のような対応関係にある(より正確には、Z0=ZEM,Z1=π,Z2={μ,Σ})のだと解釈しております。

      分割がどうであったとしても、潜在変数と他のパラメータ(特にπ)が独立であるという仮定を置きながら、p(Z)の分布を陽に置かず、πに依存した分布、すなわちp(Z|π)しか定義しない、また計算においてもそれしか用いないことが腑に落ちていません。仮にこの2変数が独立であるならば、p(Z)=p(Z|π)が成立し、分布をZのみで表現することが可能ではありませんか?

      理解が足りず大変申し訳ございませんが、何卒ご教授をお願いいたします。

    • うさぎさま

      なるほど。理解しました。

      >潜在変数と他のパラメータ(特にπ)が独立である
      こちらが自分の理解と合っておりません。
      式(90)の直後にその理由を記載してみました。
      私が勘違いしている可能性もございますので,もし疑問が残る場合は再度ご指摘いただけますでしょうか。

  • いつも記事読ませていただいております。
    (105),(107)式の4項目の係数は「-1/2」だと思います。

    • maruさま

      ご指摘ありがとうございます。おっしゃる通りですので本文を修正しました。

  • ある目的があって変分ベイズを勉強し始めたところ、このサイトにたどり着きました。難しい内容ですが、この方法の理解に大変役に立っています。
    以下、いくつか質問です。

    ・(56)の第二項の分数部分はKLダイバージェンスの定義により、q(Z)/p(Z) ではないでしょうか。
    ・(176)、(177)は (175)と同様に、∑(k=1,..,K) の総和の形ではないでしょうか。
    ・(187)の最終行の二つの項の符号は(186)より+になるのではないでしょうか。
    ・(189)の左辺の分子はNk + ν0 - D – 1ではないでしょうか。
    ・(190)の下からの文章ですが、項~に注目して、項~に関する恒等式、などはWに関するではなく∑k に関するものと思いますがいかがでしょうか。
    ・(201)は(200)と同じく∑(n=1,..,N)の総和の形ではないでしょうか。

    • bluestat 様

      ご返信遅れまして大変失礼しました。

      >式(56)について
      対数の差への分解が誤っておりました。

      >式(176)〜式(177)について
      ご指摘の通り,総和が抜けておりました。

      >式(187)について
      ご指摘の通り,符号が誤っておりました。

      >式(189)について
      式(85)のウィシャート分布の定義を利用しております。

      >式(190)について
      ご指摘の通り,Σkに関する整理でした。

      >式(201)について
      ご指摘の通り,総和が抜けておりました。

      以上,修正が完了しました。ご指摘非常に助かりました。

高橋竜眞 へ返信する コメントをキャンセル

※ Please enter your comments in Japanese to distinguish from spam.