Stan:混合モデルとラベルスイッチング
はじめに
前に混合ポアソンモデルについて勉強したが、混合モデルではラベルスイッチングの問題を考える必要がある。
詳しくはユーザーガイド・リファレンスマニュアル「24.2. 混合分布モデルでのラベルスイッチング」参照。
一番簡単な対策方法は、ポアソン分布の平均などパラメータがスカラー値を取る場合、それらに順序をつけて識別するというもの。
これは前にHMMの記事で書いた。
kento1109.hatenablog.com
スカラー値の場合、これで問題ない。では、「2つのパラメータが合計1の確率パラメータベクトル」の場合どうするのかという疑問が生じた。
このベクトル同士で順序をつけることは出来ない。
今回はそんな時にどうすれば良いかを調べてみた。
データの用意
今回は例として以下のような例を考える。
サイコロを用いたギャンブルを考える。ゲストは一定のコインを支払いギャンブルに参加する。ゲストはサイコロの出た目に応じてコインを得る(出た目が大きいほど、得られるコインを多くなる)。
これは普通のギャンブルの例である。今回は更に以下のような状況を考える。
ホストはあまりコインを払い戻したくないので、ゲストにばれないように時々イカサマしたサイコロを使用する。そのサイコロは六面体が小さい数字が出やすくなっており、これを使用することでホストは払い戻しを小さくすることが可能である。しかし、多用するとばれてしまう恐れがあるので、数回に1回程度でイカサマしたサイコロを使用することとする。
どこかの漫画の地下労働施設で聞いたことがある話だ。
さて、今回はこのようなギャンブルデータを利用してモデリングを行う。隠れ状態はホストが使用しているサイコロ(正常・イカサマ)である。
さて、2つのサイコロを投げた時の試行は以下の多項分布に従うとする。(1ゲーム20回の試行とする。)
- 正常
- イカサマ
今回のルールに沿って以下のようにデータを作成した。
data1 <- rmultinom(n=20, size = 20, prob = c(1/6,1/6,1/6,1/6,1/6,1/6)) data2 <- rmultinom(n=5, size = 20, prob = c(1/3,1/3,1/12,1/12,1/12,1/12)) data3 <- rmultinom(n=10, size = 20, prob = c(1/6,1/6,1/6,1/6,1/6,1/6)) data4 <- rmultinom(n=5, size = 20, prob = c(1/3,1/3,1/12,1/12,1/12,1/12)) data <- cbind(data1,data2,data3,data4)
今回、40ゲームを行った結果を用いる。
正常なサイコロを1、イカサマを0とした場合の各試行結果をプロットした。
今回はこれを推論することを目指す。
モデリング
今回は以下のようなコードを書いた。
基本的には前に書いたポアソン混合モデルと同じである。
kento1109.hatenablog.com
data { int N; int S; int K; int x[S,N]; } parameters { simplex[S] theta[K]; simplex[K] pi; } transformed parameters{ vector[K] lp[N]; for(n in 1:N){ for (k in 1:K) lp[n,k] = log(pi[k]) + multinomial_lpmf(x[,n] | theta[k]); } } model { for(n in 1:N){ target += log_sum_exp(lp[n]); } } generated quantities{ simplex[K] u[N]; for(n in 1:N) u[n] = softmax(lp[n]); }
generated quantities
で各クラスの尤度に応じた分類を出力するようにした。
結果①
以下のようにしてR
で実行する。
stan.dat <- list(N=ncol(data),S=nrow(data),K=2,x=data) stan.fit <- stan("multinomial_mixture.stan", data=stan.dat, chain=1)
最初はラベルスイッチングが起きなかった場合の予測を確認するためchain=1
とした。
summary(stan.fit) $summary mean se_mean sd 2.5% 25% theta[1,1] 0.28863139 0.008274501 0.05288002 1.736715e-01 2.552933e-01 theta[1,2] 0.31642265 0.008847883 0.05686860 1.918218e-01 2.838936e-01 theta[1,3] 0.10897124 0.003457797 0.02918702 5.425333e-02 8.951935e-02 theta[1,4] 0.06170202 0.005684556 0.03422223 9.984189e-03 3.689266e-02 theta[1,5] 0.11466797 0.004083461 0.03255207 6.128509e-02 9.070947e-02 theta[1,6] 0.10960474 0.003745194 0.03116486 4.718494e-02 8.925468e-02 theta[2,1] 0.17255238 0.005818873 0.04387844 7.982645e-02 1.542526e-01 theta[2,2] 0.19023815 0.007462860 0.04732092 9.273994e-02 1.731432e-01 theta[2,3] 0.15445216 0.003098625 0.03386714 9.485220e-02 1.400995e-01 theta[2,4] 0.14000995 0.004377188 0.03306693 4.466819e-02 1.297019e-01 theta[2,5] 0.17686862 0.002823444 0.02848967 1.116273e-01 1.640213e-01 theta[2,6] 0.16587875 0.004179644 0.03448394 9.932948e-02 1.509614e-01 50% 75% 97.5% n_eff Rhat theta[1,1] 2.924264e-01 3.236019e-01 0.3865116 40.84131 1.0116637 theta[1,2] 3.185399e-01 3.488754e-01 0.4243274 41.31107 1.0120015 theta[1,3] 1.081742e-01 1.289129e-01 0.1649254 71.24926 1.0040223 theta[1,4] 5.672305e-02 8.089627e-02 0.1403246 36.24295 1.0163097 theta[1,5] 1.119642e-01 1.360349e-01 0.1821790 63.54780 1.0133100 theta[1,6] 1.091764e-01 1.300028e-01 0.1745309 69.24394 1.0029883 theta[2,1] 1.712500e-01 1.877133e-01 0.2790961 56.86231 1.0320961 theta[2,2] 1.910106e-01 2.051154e-01 0.3156735 40.20646 1.0344984 theta[2,3] 1.515950e-01 1.645796e-01 0.2385889 119.45905 1.0025754 theta[2,4] 1.415028e-01 1.544405e-01 0.1898349 57.06863 1.0224458 theta[2,5] 1.771090e-01 1.925877e-01 0.2260736 101.81615 1.0134331 theta[2,6] 1.632656e-01 1.771684e-01 0.2335396 68.06991 1.0327917
RHat
も問題なさそうだ。
以下のように各試行の推定結果と真のクラスを比較した。
stan.params <- rstan::extract(stan.fit,pars=c("u")) u.mean <- apply(stan.params$u, c(2,3), mean) plot(c(rep(1,20),rep(0,5),rep(1,10),rep(0,5)),ylim=c(0,1), ylab="") par(new=T) plot(u.mean[,2],ylim=c(0,1),ylab="",col="red") abline(h=0.5,col="blue")
全て正しく推定できているのが分かる。
結果②
次にラベルスイッチングを確認するため、chain=4
で実行する。
stan.dat <- list(N=ncol(data),S=nrow(data),K=2,x=data) stan.fit <- stan("multinomial_mixture.stan", data=stan.dat, chain=4)
結果は以下の通り。
summary(stan.fit) $summary mean se_mean sd 2.5% 25% theta[1,1] 0.24066245 0.03459264 0.07791532 1.283242e-01 1.728793e-01 theta[1,2] 0.26262961 0.03765012 0.08111626 1.352325e-01 1.922499e-01 theta[1,3] 0.12726122 0.01346508 0.03810829 5.826672e-02 9.983748e-02 theta[1,4] 0.09489373 0.02386259 0.05298067 1.092301e-02 4.750612e-02 theta[1,5] 0.14073448 0.01787237 0.04213303 6.460134e-02 1.059021e-01 theta[1,6] 0.13381852 0.01759404 0.04199431 5.904179e-02 1.016729e-01 theta[2,1] 0.22749524 0.03557697 0.07244701 1.342996e-01 1.686509e-01 theta[2,2] 0.24890954 0.03837881 0.07789662 1.391813e-01 1.869382e-01 theta[2,3] 0.13220354 0.01392479 0.03369018 6.264636e-02 1.084989e-01 theta[2,4] 0.10408959 0.02409088 0.04913367 1.668392e-02 5.707473e-02 theta[2,5] 0.14837353 0.01849434 0.04121531 6.799874e-02 1.131823e-01 theta[2,6] 0.13892856 0.01813406 0.03894839 6.006671e-02 1.092404e-01 pi[1] 0.47329259 0.10596077 0.25259796 8.878489e-02 2.454217e-01 pi[2] 0.52670741 0.10596077 0.25259796 1.156611e-01 2.875244e-01 50% 75% 97.5% n_eff Rhat theta[1,1] 2.395716e-01 0.3042625 0.3793181 5.073158 1.482184 theta[1,2] 2.680819e-01 0.3265333 0.4125085 4.641765 1.526439 theta[1,3] 1.307148e-01 0.1521753 0.1902564 8.009789 1.224350 theta[1,4] 9.328236e-02 0.1407524 0.1744657 4.929465 1.524555 theta[1,5] 1.441672e-01 0.1753437 0.2103470 5.557524 1.411984 theta[1,6] 1.368468e-01 0.1635550 0.2038121 5.697047 1.334523 theta[2,1] 1.961380e-01 0.2910343 0.3748417 4.146704 1.675218 theta[2,2] 2.161485e-01 0.3167979 0.4003586 4.119592 1.695924 theta[2,3] 1.379141e-01 0.1554845 0.1892626 5.853685 1.347178 theta[2,4] 1.230448e-01 0.1442089 0.1729441 4.159614 1.693136 theta[2,5] 1.588737e-01 0.1801457 0.2117313 4.966378 1.503050 theta[2,6] 1.467403e-01 0.1672130 0.2064530 4.613060 1.485162
RHat
が軒並み大きくなってしまっている。
トレース結果は以下の通り。
stan_trace(stan.fit,pars = "theta")
chainによって異なるバラつきになっているのが分かる。
結果①と同様に予測結果をプロットする。
予測がほとんど0.5あたりになっているのが分かる。
(ラベルスイッチングの問題)
次にラベルスイッチングに対応するための方法について書く。
label.switching
調べてみると、Rパッケージに
label.switching
というのがあった。CRAN - Package label.switching
この中のアルゴリズムを利用してラベルスイッチング問題を解決できるそうだ。
参考:https://arxiv.org/pdf/1503.02271.pdf
AIC
赤池情報量基準ではなく、「Artificial Identifiability Constraints」。
関数呼び出しは以下のようにして行う。
library("label.switching") mcmc.params <- rstan::extract(stan.fit) z <- matrix(ncol=ncol(data),nrow=4000) for(i in 1:(4000)){ z[i,] <- apply(mcmc.params$u[i,,],1,which.max) } ls <- label.switching(method = "AIC", z = z, K = 2, mcmc = mcmc.params$theta)
参考:
習作;ラベルスイッチング問題をなんとかする関数 – Kosugitti's BLOG
これでz
の40件のデータのラベルを降り直す。
新しいラベルはls$clusters
で確認できる。
ls$clusters [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14] AIC 1 1 1 1 1 1 1 1 1 1 1 1 1 1 [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27] AIC 1 1 1 1 1 1 2 2 2 2 2 1 1 [,28] [,29] [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38] [,39] [,40] AIC 1 1 1 1 1 1 1 1 2 2 2 2 2
再度、正解ラベルと比べてみる。
plot(c(rep(1,20)-.05,rep(0,5)+.05,rep(1,10)-.05,rep(0,5)+.05),ylim=c(0,1), ylab="") par(new=T) y_pred <- ifelse(ls$cluster==2, 0, 1) plot(t(y_pred),ylim=c(0,1),ylab="",col="red")
全て正しく一致していることが分かる。
※全て完全一致していたので、微妙にラベルの位置をずらした。
最後に
今回、実験的に
AIC
関数を使ってみた。今回は使ってみただけで、
label.switching
関数の様々なアルゴリズムまで理解できていないので、別の機会に学習しようと思う。あと、今回はラベルスイッチング問題を解決することに一生懸命になったが、
ある意味, 混合成分のインデックス(あるいはラベル)は重要ではありません. 事後予測推測は, 混合成分が識別できなくとも可能です. 例えば, 新しい観測値の対数確率は, 混合成分が識別されるかに依存しません.
(ユーザーガイド・リファレンスマニュアル「24.2. 混合分布モデルでのラベルスイッチング」)
とある。
要はRHat
が大きな値だったり、各Chainが異なった値でバラついていた場合、「ラベルスイッチングが起きてしまっており、このモデルは問題がある。何とかラベルスイッチングを回避しないといけない。」とすぐに考えるのではなく、目的に応じて対応策を考えるのが必要だということだ。
実はこれを意識することが一番大事かもしれない・・