Stan:TwoCountryQuiz①
はじめに
前回は「コワイ本」のTwentyQuestionsを取り組んだ。
kento1109.hatenablog.com
今回はそれをもう少し難しくした「TwoCountryQuiz」に取り組む。
TwentyQuestionsのコードを少し変えるだけと思っていたが、実はとても難しく嵌りどころがたくさんだった・・
それはともかく、さっそく問題を見ていく。
TwoCountryQuiz
問題の設定としては以下の通り(本文より引用)
あるグループの人々が歴史クイズを出されて、各人の各回答は正答または誤答として得点化される。人々の一部はタイ人、一部はモルドバ人である。問題の一部はタイの歴史についてのものであり、モルドバ人よりもタイ人が答えを知ってそうな問題である。残りの問題はモルドバの歴史についてのものであり、タイ人よりもモルドバ人が答えを知っていそうな問題である。
ここでは、誰がタイ人で誰がモルドバ人なのか分からず、問題の内容も分からない。データからどの人が同じ国から来た人なのか、どの問題が彼らの国と関係するのかを推定してほしい。
ちなみにモルドバ人はルーマニア語を話すラテン系の民族集団らしい。
モルドバ人 - Wikipedia
モデルの作成
ここでは2種類の回答を仮定する。
人の国籍が問題の起源と合致する場合、回答率は高い確率で正しい。(タイ人のタイの歴史の問題に対する正答率)
一方、ある人が他の国について尋ねられた場合には、回答は低い確率で正しい。(タイ人のモルドバの歴史の問題に対する正答率)
このアイデアをグラフィカルモデルにすると以下となる。
この場合、はそれぞれ国のインデックスを示す。
今回、は以下のように定義する。
正答率はある人が自身の歴史の問題に対する正答率、正答率はある人が他国の歴史の問題に対する正答率。仮定よりの制約を設ける。
離散パラメータの対応
WinBUGS
の場合、このグラフィカルモデルで問題ない。ただし、そのままStan
で扱うことは出来ない。パラメータは(各国を示すインデックス)が離散値だからである。離散パラメータに対応する最もシンプルな方法は「周辺化除去(marginalization )」である。詳しくは以前に書いた。
どう実装するべきかと調べていたら、TwoCountryQuizってどうやってStan
で書くのかっていうディスカッションがあった。
groups.google.com
今回の場合は各離散パラメータの値に応じて確率分布が異なる(と)ので、周辺化除去ができない。なので、パラメータを確率変数として定義するのが妥当なアイデアみたいだ。は番目の人がタイ人(モルドバ人)である確率、は問題がタイ(モルドバ)の歴史に関する問題の確率を表す。
これでStan
での実装が可能となる。
モデリング
Googleでのディスカッションのアイデアに基づき以下のように実装した。
data{ int<lower=1> nx; // num person int<lower=1> nz; // num question int<lower=0, upper=1> k[nx, nz]; // person * question real<lower=0> a; // hyper parameter real<lower=0> b; // hyper parameter } parameters { real<lower=0, upper=1> x_prob[nx]; real<lower=0, upper=1> z_prob[nz]; real<lower=0, upper=1> alpha; real<lower=0, upper=alpha> beta; } model { vector[4] lp; for(i in 1:nx) x_prob[i] ~ beta(a,b);; // prior for(j in 1:nz) z_prob[j] ~ beta(a,b);; // prior for(i in 1:nx){ for(j in 1:nz){ lp[1] = (log1m(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha)); lp[2] = (log(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha)); lp[3] = (log1m(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta)); lp[4] = (log(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta)); target += log_sum_exp(lp); } } }
※log1m(a)
はを効率的に計算するための関数
ポイントは尤度計算の部分。
今回は人と問題の組み合わせが2×2なので4つの式に基づき尤度が計算される。
k[i, j] |
タイの歴史 | モルドバの歴史 |
タイ人 | lp[1] |
lp[3] |
モルドバ人 | lp[4] |
lp[2] |
組み合わせの計算
なぜこのような尤度計算になるのか簡単な例で考える。
あるサイコロの出目の確率は全て「1/6」のとき、「6」が出る確率は当然「1/6」である。
次にサイコロが2つの場合を考える。
Aのサイコロの「6」が出る確率はであり、Bのサイコロの「6」が出る確率とする。また、Aを投げる確率がであるとする。
このとき、「6」が出る確率は以下のような計算となる。
このとき、AとBを投げる確率は排反事象であり、計算は「AまたはB」の組み合わせの確率(足し算)で求める。
これにもう一つ条件を加える。
投げ手(ホスト・ゲスト)によってサイコロの出目が変わるとする。ホストが投げる確率を、サイコロと投げ手によるサイコロの「6」が出る確率は以下とする。
ホスト | ゲスト | |
A | ||
B |
このとき、「6」が出る確率は以下のような計算となる。
これの対数を取ると今回のような計算になることが分かる。
推論結果
以下のようにしてデータを渡す。
尚、実験データは以下に作った。
github.com
TwoCoutryQuiz <- read.csv("TwoCountryQuiz.txt") stan.dat <- list(nx=nrow(TwoCoutryQuiz),nz=ncol(TwoCoutryQuiz),k=TwoCoutryQuiz,a=1.0,b=1.0) stan.fit <- stan("TwoCountryQuiz.stan",data=stan.dat)
推論結果は以下の通り。
summary(stan.fit) $summary mean se_mean sd 2.5% 25% 50% x_prob[1] 0.53029713 0.036775653 0.3541604 0.011127718 0.15609154 0.5985944 x_prob[2] 0.52931333 0.037102322 0.3561591 0.010755941 0.15249905 0.5860287 x_prob[3] 0.47212875 0.030683095 0.3283253 0.011333054 0.15645041 0.4296219 x_prob[4] 0.47621507 0.025471547 0.3018463 0.019701274 0.19461178 0.4615459 x_prob[5] 0.53508411 0.037486953 0.3561731 0.013228531 0.16077255 0.6045322 x_prob[6] 0.52638124 0.031305999 0.3259938 0.022406221 0.20644299 0.5571316 x_prob[7] 0.46897241 0.033495749 0.3356998 0.015722103 0.14365877 0.4235867 x_prob[8] 0.47378394 0.030727943 0.3240059 0.013129741 0.16318500 0.4420749 z_prob[1] 0.52560515 0.031547300 0.3286305 0.018388960 0.20320846 0.5550213 z_prob[2] 0.47342530 0.032388415 0.3323511 0.010483992 0.15403045 0.4412354 z_prob[3] 0.48138335 0.024288168 0.2972166 0.021458533 0.21662389 0.4679963 z_prob[4] 0.52477091 0.030564488 0.3238213 0.021646336 0.20497649 0.5546996 z_prob[5] 0.53043538 0.037612554 0.3594749 0.009185296 0.14711397 0.5980563 z_prob[6] 0.46748793 0.036646445 0.3567124 0.008275019 0.11662093 0.4036751 z_prob[7] 0.46928308 0.031075766 0.3271401 0.012422863 0.15836508 0.4300338 z_prob[8] 0.53520960 0.036506380 0.3559883 0.011017414 0.15899693 0.6020018 alpha 0.84288128 0.004903208 0.1328105 0.495493896 0.77739007 0.8791112 beta 0.08643854 0.003764911 0.1007688 0.001575219 0.01844138 0.0486498 lp__ -76.06556122 0.129462953 3.7709042 -84.645570937 -78.28974038 -75.5811353 75% 97.5% n_eff Rhat x_prob[1] 0.8788707 0.9907077 92.74258 1.059581 x_prob[2] 0.8841432 0.9894480 92.14802 1.062257 x_prob[3] 0.7910437 0.9855612 114.50130 1.044380 x_prob[4] 0.7592288 0.9719845 140.43037 1.040285 x_prob[5] 0.8874325 0.9919339 90.27383 1.065438 x_prob[6] 0.8409497 0.9847431 108.43353 1.052919 x_prob[7] 0.8063384 0.9830522 100.44369 1.057658 x_prob[8] 0.7888879 0.9829815 111.18311 1.049953 z_prob[1] 0.8405482 0.9845766 108.51542 1.050264 z_prob[2] 0.8033599 0.9829672 105.29670 1.055173 z_prob[3] 0.7437687 0.9731136 149.74652 1.035735 z_prob[4] 0.8337083 0.9829047 112.24744 1.046241 z_prob[5] 0.8862753 0.9926954 91.34220 1.062050 z_prob[6] 0.8405933 0.9903413 94.74860 1.060763 z_prob[7] 0.7935304 0.9819733 110.82146 1.054304 z_prob[8] 0.8860739 0.9926517 95.08980 1.058714 alpha 0.9451407 0.9944410 733.67550 1.002910 beta 0.1123582 0.3918710 716.37944 1.002677 lp__ -73.4013764 -69.9297720 848.39868 1.002594
これを見る限りRHat
は問題なさそうである。
ただし、n_eff
が小さく見直しの余地あり。
事後分布とサンプルのトレースを確認すればモデルが良くないことが分かる。
事後分布
stan_dens(stan.fit,pars ="x_prob[1]",separate_chains = TRUE)
ここから分かる問題は
- 明らかに分布が多峰性となっている
- chainによる分布のバラつきが大きい。
ということである。
ちなみに理想的な事後分布はこんな感じ。
トレース結果
stan_trace(stan.fit,pars = "x_prob[1]")
理想としてはサンプリングがどこかの値付近でバラついてほしい。
しかし、取り得る範囲0~1の間を広く動いている。
(その結果、平均が0.5あたりになるのであろう。)
どう考えてもうまくいっていない・・
なぜこのような問題が起きたのか。
原因はおそらく次の2つだと考えている。
- データ数の不足
- ラベルスイッチング
次回、この問題に対する対処方法を考えたい。