Stan:TwoCountryQuiz②
はじめに
前回は「コワイ本」のTwoCountryQuizのタスクに取り組んだ。
今回はこの続きについてまとめる。
ラベルスイッチング
前回のモデルは事後分布が二峰性をもつような分布となっていた。
この原因として「ラベルスイッチング」が原因だと述べた。
混合モデルを考える場合、これはやはり避けて通れない。
はじめに「何故そう考えたか」について尤度計算を交えて説明する。
まず、今回のデータ(TwoCountryQuiz)は以下のような行列だった。
A | B | ... | H | |
人1 | 1 | 0 | ... | 1 |
人2 | 1 | 0 | ... | 1 |
... | ... | ... | ... | ... |
人8 | 0 | 1 | ... | 0 |
そして、
- 人1が「タイ人かどうか」
- 問題Aが「タイの歴史に関するものかどうか」
を推論することが今回の目的だった。
これは以下の推論でも構わない。
そして、当たり前だがモデルはユーザーがどっちを想定しているかなど知らない。
この推論が同じ結果になることを尤度計算により確認する。
全体尤度の計算は以下の通り。
対数を取ると以下の通り。
log_sum_exp
を使うと以下の通り。
※各式の意味に関しては前回参照
実際に人1の問題Aに関する尤度を計算する。
尚、計算時点でのパラメータは以下で与えられたとする。
計算結果は以下の通り。
他のパラメータを固定したもとでに更新すると尤度はどう変化するか確認する。
計算結果は以下の通り。
尤度が小さくなることが分かる。
これは、「各人がタイ人(モルドバ人)であるとき、問題がモルドバ(タイ)の歴史」としたの場合の尤度に近い(確率なので同じではない)ので尤度が小さくなる。
では、この条件下でに更新すると尤度はどう変化するか。
最初のパラメータでの尤度と等しくなる。
つまり、に関係なく、とでの尤度は等しくなるというわけである。
そんな訳で尤度を大きくするパラメータは、でもでも[z_1]によっては尤度が同じということになり、事後分布が二峰性を持つことになるのである。
パラメータの制約
では、どうすればよいか。
問題はでも尤度が同じになるということだった。では、が取れる範囲を制約してしまおうというのが1つの対処法である。
例えば、の制約を設けることで、その制約下で尤度を大きくするためのを探索する。
実装方法としては以下のような書き方が考えられる。
parameters { real<lower=0.5, upper=1> init_x_prob; real<lower=0, upper=1> other_x_prob[nx-1]; } transformed parameters { real x_prob[nx]; x_prob[1] = init_x_prob; for(i in 1:nx-1) x_prob[i+1] = other_x_prob[i]; }
これでの乱数生成をに制限する。
※ベクトルの1要素のみに制約を設定する方法は既にグーグルフォーラムで挙がっていたので参考とさせて頂いた。
Google グループ
モデリング
全体の
Stan
は以下の通り。の制約以外はほとんど前と同じ。
data{ int<lower=1> nx; // num person int<lower=1> nz; // num question int<lower=0, upper=1> k[nx, nz]; // person * question } parameters { real<lower=0.5, upper=1> init_x_prob; real<lower=0, upper=1> other_x_prob[nx-1]; real<lower=0, upper=1> z_prob[nz]; real<lower=0, upper=1> alpha; real<lower=0, upper=alpha> beta; } transformed parameters { real x_prob[nx]; x_prob[1] = init_x_prob; for(i in 1:nx-1) x_prob[i+1] = other_x_prob[i]; } model { vector[4] lp; //for(i in 1:nx) // x_prob[i] ~ beta(a,b);; // prior //for(j in 1:nz) // z_prob[j] ~ beta(a,b);; // prior for(i in 1:nx){ for(j in 1:nz){ lp[1] = (log1m(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha)); lp[2] = (log(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha)); lp[3] = (log1m(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta)); lp[4] = (log(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta)); target += log_sum_exp(lp); } } }
結果の確認
実行結果は以下の通り。
stan.fit <- stan("TwoCountryQuiz.stan",data=stan.dat,pars=c('x_prob','z_prob'))
summary(stan.fit) $summary mean se_mean sd 2.5% 25% 50% x_prob[1] 0.8257743 0.004583606 0.1372088 0.528647083 0.73226802 0.8620697 x_prob[2] 0.7505101 0.013206512 0.2454967 0.077761085 0.65885906 0.8371625 x_prob[3] 0.2848253 0.011180096 0.2422784 0.008530586 0.09599669 0.2132672 x_prob[4] 0.3344115 0.008289728 0.2467132 0.017356223 0.13546493 0.2774034 x_prob[5] 0.7564982 0.013374523 0.2441466 0.090157994 0.66361202 0.8418284 x_prob[6] 0.7106300 0.011092027 0.2467630 0.087634316 0.57762401 0.7808384 x_prob[7] 0.2776645 0.010963704 0.2489371 0.007548964 0.08553814 0.1958664 x_prob[8] 0.2905203 0.010156586 0.2449326 0.007862108 0.09827746 0.2208900 z_prob[1] 0.7189481 0.010696188 0.2444943 0.085165215 0.60045169 0.7911544 z_prob[2] 0.2712665 0.010862294 0.2398265 0.009096772 0.08682635 0.1975436 z_prob[3] 0.3250752 0.007364386 0.2474678 0.015272428 0.12510105 0.2692455 z_prob[4] 0.7139130 0.011395878 0.2432047 0.099609097 0.59076957 0.7835717 z_prob[5] 0.7556119 0.012485880 0.2441034 0.091712754 0.65056501 0.8450558 z_prob[6] 0.2471072 0.012882019 0.2460234 0.004999654 0.06493085 0.1560054 z_prob[7] 0.2848787 0.012258488 0.2448348 0.010895594 0.09539916 0.2123684 z_prob[8] 0.7588505 0.012929736 0.2370986 0.105125385 0.66719031 0.8418040 lp__ -76.6538982 0.173507422 4.0323283 -85.464497249 -79.14733785 -76.2325885 75% 97.5% n_eff Rhat x_prob[1] 0.9405903 0.9950119 896.0852 1.002287 x_prob[2] 0.9309936 0.9919926 345.5532 1.005609 x_prob[3] 0.4083455 0.9058809 469.6110 1.002421 x_prob[4] 0.4877216 0.9151902 885.7355 1.002222 x_prob[5] 0.9374089 0.9953916 333.2304 1.006969 x_prob[6] 0.9068108 0.9915599 494.9237 1.002667 x_prob[7] 0.4008561 0.9103599 515.5425 1.004395 x_prob[8] 0.4202415 0.9004982 581.5642 1.004779 z_prob[1] 0.9100841 0.9919731 522.4915 1.004377 z_prob[2] 0.3818967 0.9177776 487.4739 1.005363 z_prob[3] 0.4726530 0.9227644 1129.1828 1.001313 z_prob[4] 0.9048383 0.9913114 455.4579 1.002746 z_prob[5] 0.9364634 0.9945814 382.2165 1.003059 z_prob[6] 0.3527158 0.9079530 364.7412 1.005107 z_prob[7] 0.4027377 0.9158467 398.9079 1.005976 z_prob[8] 0.9324511 0.9944236 336.2627 1.007949 lp__ -73.6854277 -70.1485710 540.1018 1.002221
収束判定は問題なさそう。
確率も前回のような0.5周辺ではなく、きっちり分かれている。
事後分布
前回、二峰性だった事後分布を確認する。
stan_dens(stan.fit,pars ="x_prob[1]",separate_chains = TRUE)
単峰性の分布であることが確認できる。chainによるバラツキもあまりない。裾野が長いのが気になるが、これはサンプル数の問題と思われる。(後で検証してみる。)
一応、制約を加えなかったパラメータも確認する。
stan_dens(stan.fit,pars ="x_prob[3]",separate_chains = TRUE)
二峰性は見られないことが分かる。
トレース結果
stan_trace(stan.fit,pars = "x_prob[1]")
制約を加えたこともあり、制約範囲内で動いている。(非常に分散が大きいのが少し気にがるが・・)
同様に制約を加えなかったパラメータも確認する。
stan_trace(stan.fit,pars = "x_prob[3]")
分散は大きいがあたりで多くサンプリングされていることが分かる。
サンプル数を増やしてモデリング
最後にサンプル数を24人に増やして再度モデリングを行った。
(データセットは単純に8人分のデータ×3しただけ)
結果は以下の通り。
summary(stan.fit) $summary mean se_mean sd 2.5% 25% x_prob[1] 0.88909723 0.0015264536 0.09654140 6.379189e-01 0.84088896 x_prob[2] 0.88306846 0.0017073929 0.10798501 5.979277e-01 0.83358725 x_prob[3] 0.17779696 0.0022253553 0.14074382 6.284815e-03 0.06514292 x_prob[4] 0.24431969 0.0026401563 0.16697814 1.462065e-02 0.11160276 x_prob[5] 0.88339388 0.0016822210 0.10639300 5.915050e-01 0.83695690 x_prob[6] 0.82474879 0.0021329165 0.13489748 4.874794e-01 0.74488734 x_prob[7] 0.15902617 0.0020105498 0.12715834 5.048530e-03 0.05834003 x_prob[8] 0.17922045 0.0021341862 0.13497779 7.067403e-03 0.07218645 x_prob[9] 0.88388361 0.0016258287 0.10282643 6.155001e-01 0.83528646 x_prob[10] 0.88482958 0.0016178430 0.10232138 6.219439e-01 0.83460776 x_prob[11] 0.17376721 0.0021239378 0.13432962 7.741369e-03 0.07019395 x_prob[12] 0.24658736 0.0025807395 0.16322030 1.360339e-02 0.11582604 x_prob[13] 0.88225775 0.0016771932 0.10607501 6.072548e-01 0.82842426 x_prob[14] 0.82616041 0.0021704359 0.13727042 4.905692e-01 0.74875851 x_prob[15] 0.15963763 0.0021007873 0.13286546 4.750483e-03 0.05315915 x_prob[16] 0.18142896 0.0021570449 0.13642350 6.887512e-03 0.07567072 x_prob[17] 0.88654126 0.0016228331 0.10263698 6.202509e-01 0.83865202 x_prob[18] 0.88467146 0.0016392687 0.10367646 6.029201e-01 0.83315147 x_prob[19] 0.17220191 0.0021321826 0.13485107 6.665817e-03 0.06455919 x_prob[20] 0.24572632 0.0026661863 0.16862443 1.241102e-02 0.11079253 x_prob[21] 0.88259919 0.0016423417 0.10387081 6.201909e-01 0.83208995 x_prob[22] 0.82613928 0.0022048355 0.13944604 4.796073e-01 0.75000947 x_prob[23] 0.15867396 0.0020519952 0.12977957 5.023390e-03 0.05646634 x_prob[24] 0.18281938 0.0021713075 0.13732554 9.281820e-03 0.07586258 z_prob[1] 0.90059791 0.0013291991 0.08406593 6.808949e-01 0.85748083 z_prob[2] 0.08281926 0.0011276441 0.07131847 2.591367e-03 0.02921736 z_prob[3] 0.16589135 0.0018243489 0.11538195 6.090128e-03 0.07509763 z_prob[4] 0.89537937 0.0013233079 0.08369334 6.885263e-01 0.85128247 z_prob[5] 0.94851652 0.0007727115 0.04887056 8.209348e-01 0.92862688 z_prob[6] 0.05198457 0.0007983830 0.05049418 1.322733e-03 0.01514927 z_prob[7] 0.09923228 0.0012919329 0.08170901 3.422783e-03 0.03655729 z_prob[8] 0.94831005 0.0008012587 0.05067605 8.110328e-01 0.92770249 lp__ -165.12638755 0.1230833587 4.69380930 -1.754599e+02 -167.91580233 50% 75% 97.5% n_eff Rhat x_prob[1] 0.91570868 0.9632312 0.9965173 4000.000 1.0000139 x_prob[2] 0.91501176 0.9641354 0.9965150 4000.000 0.9997085 x_prob[3] 0.14653449 0.2576420 0.5169342 4000.000 0.9991991 x_prob[4] 0.21428473 0.3508085 0.6205980 4000.000 0.9994822 x_prob[5] 0.91291980 0.9619288 0.9971203 4000.000 0.9994962 x_prob[6] 0.85376265 0.9305850 0.9914166 4000.000 0.9991922 x_prob[7] 0.12972880 0.2282888 0.4739765 4000.000 0.9999588 x_prob[8] 0.15058849 0.2521316 0.5022369 4000.000 0.9998293 x_prob[9] 0.91202499 0.9622890 0.9963306 4000.000 0.9999095 x_prob[10] 0.91357950 0.9628045 0.9967515 4000.000 0.9998023 x_prob[11] 0.14187091 0.2502796 0.4986213 4000.000 0.9994550 x_prob[12] 0.22103501 0.3514679 0.6198173 4000.000 1.0002860 x_prob[13] 0.91473044 0.9630547 0.9967746 4000.000 0.9994394 x_prob[14] 0.85865190 0.9357102 0.9948808 4000.000 0.9998928 x_prob[15] 0.12518733 0.2320583 0.4899430 4000.000 0.9995863 x_prob[16] 0.15374736 0.2578792 0.5127041 4000.000 0.9997273 x_prob[17] 0.91694762 0.9645227 0.9965024 4000.000 0.9995645 x_prob[18] 0.91524153 0.9621940 0.9968299 4000.000 0.9993828 x_prob[19] 0.14001510 0.2524526 0.4938968 4000.000 0.9995974 x_prob[20] 0.21958370 0.3541058 0.6335430 4000.000 0.9998727 x_prob[21] 0.91241451 0.9615494 0.9964914 4000.000 0.9999107 x_prob[22] 0.85743842 0.9350120 0.9944499 4000.000 0.9998915 x_prob[23] 0.12682041 0.2299418 0.4853042 4000.000 1.0002858 x_prob[24] 0.15102426 0.2611727 0.5188826 4000.000 0.9996171 z_prob[1] 0.92263662 0.9645727 0.9961818 4000.000 0.9993239 z_prob[2] 0.06412635 0.1186724 0.2634773 4000.000 0.9995061 z_prob[3] 0.14696709 0.2362540 0.4324484 4000.000 0.9992923 z_prob[4] 0.91543158 0.9591549 0.9955493 4000.000 0.9997276 z_prob[5] 0.96217694 0.9843390 0.9986652 4000.000 0.9999034 z_prob[6] 0.03690516 0.0717256 0.1881421 4000.000 1.0008297 z_prob[7] 0.07866694 0.1412252 0.2983229 4000.000 0.9998847 z_prob[8] 0.96319514 0.9854857 0.9990697 4000.000 0.9995410 lp__ -164.78851211 -161.7427068 -157.0022084 1454.294 1.0035093
全体的に分散が小さくなっていることが分かる。
事後分布
サンプル数を増やしたことで裾野が短くなっていることが分かる。
トレース結果
トレース結果からも分散が小さくなっていることが分かる。
さいごに
TwoCountryQuizを取り組んだが、実力不足もあり時間がかかってしまった。しかし、ラベルスイッチングやパラメータの制約に関して理解が深まったので良かった。
次はTwoCountryQuizを応用した別のタスクをを検討したい。