機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Stan:TwoCountryQuiz①

はじめに



前回は「コワイ本」のTwentyQuestionsを取り組んだ。
kento1109.hatenablog.com

今回はそれをもう少し難しくした「TwoCountryQuiz」に取り組む。
TwentyQuestionsのコードを少し変えるだけと思っていたが、実はとても難しく嵌りどころがたくさんだった・・

それはともかく、さっそく問題を見ていく。

TwoCountryQuiz



問題の設定としては以下の通り(本文より引用)

あるグループの人々が歴史クイズを出されて、各人の各回答は正答または誤答として得点化される。人々の一部はタイ人、一部はモルドバ人である。問題の一部はタイの歴史についてのものであり、モルドバ人よりもタイ人が答えを知ってそうな問題である。残りの問題はモルドバの歴史についてのものであり、タイ人よりもモルドバ人が答えを知っていそうな問題である。
ここでは、誰がタイ人で誰がモルドバ人なのか分からず、問題の内容も分からない。データからどの人が同じ国から来た人なのか、どの問題が彼らの国と関係するのかを推定してほしい。

ちなみにモルドバ人はルーマニア語を話すラテン系の民族集団らしい。
モルドバ人 - Wikipedia

モデルの作成



ここでは2種類の回答を仮定する。
人の国籍が問題の起源と合致する場合、回答率は高い確率で正しい。(タイ人のタイの歴史の問題に対する正答率)
一方、ある人が他の国について尋ねられた場合には、回答は低い確率で正しい。(タイ人のモルドバの歴史の問題に対する正答率)
このアイデアをグラフィカルモデルにすると以下となる。

f:id:kento1109:20180704160348p:plain

この場合、x_i,z_jはそれぞれ国のインデックスを示す。
今回、\theta_{ij}は以下のように定義する。

  \theta_{ij} = \begin{cases}
    \alpha & (x_i=z_j) \\
    \beta & (x_i\ne z_j)
  \end{cases}
正答率\alphaはある人が自身の歴史の問題に対する正答率、正答率\betaはある人が他国の歴史の問題に対する正答率。仮定より\alpha>\betaの制約を設ける。

離散パラメータの対応



WinBUGSの場合、このグラフィカルモデルで問題ない。ただし、そのままStanで扱うことは出来ない。パラメータx_i,z_jは(各国を示すインデックス)が離散値だからである。
離散パラメータに対応する最もシンプルな方法は「周辺化除去(marginalization )」である。詳しくは以前に書いた。

kento1109.hatenablog.com

どう実装するべきかと調べていたら、TwoCountryQuizってどうやってStanで書くのかっていうディスカッションがあった。
groups.google.com

今回の場合は各離散パラメータの値に応じて確率分布が異なる(\alpha\beta)ので、周辺化除去ができない。なので、パラメータx_i,z_jを確率変数として定義するのが妥当なアイデアみたいだ。x_ii番目の人がタイ人(モルドバ人)である確率、z_jは問題jがタイ(モルドバ)の歴史に関する問題の確率を表す。
これでStanでの実装が可能となる。

モデリング



Googleでのディスカッションのアイデアに基づき以下のように実装した。

data{
  int<lower=1> nx;              // num person
  int<lower=1> nz;              // num question
  int<lower=0, upper=1> k[nx, nz];   // person * question
  real<lower=0> a;          // hyper parameter
  real<lower=0> b;           // hyper parameter
}
parameters {
  real<lower=0, upper=1> x_prob[nx];
  real<lower=0, upper=1> z_prob[nz];
  real<lower=0, upper=1> alpha;
  real<lower=0, upper=alpha> beta;
}
model {
  vector[4] lp;
  for(i in 1:nx)
    x_prob[i] ~ beta(a,b);;  // prior
  for(j in 1:nz)
    z_prob[j] ~ beta(a,b);;  // prior
  for(i in 1:nx){
    for(j in 1:nz){
      lp[1] = (log1m(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha));
      lp[2] = (log(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha));
      lp[3] = (log1m(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta));
      lp[4] = (log(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta));
      target += log_sum_exp(lp);
    }
  }
}

log1m(a)\log (1-a)を効率的に計算するための関数

ポイントは尤度計算の部分。
今回は人と問題の組み合わせが2×2なので4つの式に基づき尤度が計算される。

k[i, j] タイの歴史 モルドバの歴史
タイ人 lp[1] lp[3]
モルドバ lp[4] lp[2]
組み合わせの計算

なぜこのような尤度計算になるのか簡単な例で考える。

あるサイコロの出目の確率は全て「1/6」のとき、「6」が出る確率は当然「1/6」である。
次にサイコロが2つの場合を考える。
Aのサイコロの「6」が出る確率はp_aであり、Bのサイコロの「6」が出る確率p_bとする。また、Aを投げる確率がp_Aであるとする。
このとき、「6」が出る確率は以下のような計算となる。

\begin{eqnarray}p_A\times p_a+(1-p_A)\times p_b
\end{eqnarray}
このとき、AとBを投げる確率は排反事象であり、計算は「AまたはB」の組み合わせの確率(足し算)で求める。
これにもう一つ条件を加える。
投げ手(ホスト・ゲスト)によってサイコロの出目が変わるとする。ホストが投げる確率をp_H、サイコロと投げ手によるサイコロの「6」が出る確率は以下とする。

ホスト ゲスト
A p_{a,h} p_{a,g}
B p_{b,h} p_{b,g}

このとき、「6」が出る確率は以下のような計算となる。

\begin{eqnarray}\{p_A\times p_H \times p_{a,h}\}+\{
p_A\times (1-p_H) \times p_{a,g}\} +\{
(1-p_A)\times p_H\times p_{b,h} \}+\{
(1-p_A)\times (1-p_H)\times p_{b,g}\}
\end{eqnarray}
これの対数を取ると今回のような計算になることが分かる。

推論結果



以下のようにしてデータを渡す。

尚、実験データは以下に作った。
github.com

TwoCoutryQuiz <- read.csv("TwoCountryQuiz.txt")
stan.dat <- list(nx=nrow(TwoCoutryQuiz),nz=ncol(TwoCoutryQuiz),k=TwoCoutryQuiz,a=1.0,b=1.0)
stan.fit <- stan("TwoCountryQuiz.stan",data=stan.dat)

推論結果は以下の通り。

summary(stan.fit)
$summary
                  mean     se_mean        sd          2.5%          25%         50%
x_prob[1]   0.53029713 0.036775653 0.3541604   0.011127718   0.15609154   0.5985944
x_prob[2]   0.52931333 0.037102322 0.3561591   0.010755941   0.15249905   0.5860287
x_prob[3]   0.47212875 0.030683095 0.3283253   0.011333054   0.15645041   0.4296219
x_prob[4]   0.47621507 0.025471547 0.3018463   0.019701274   0.19461178   0.4615459
x_prob[5]   0.53508411 0.037486953 0.3561731   0.013228531   0.16077255   0.6045322
x_prob[6]   0.52638124 0.031305999 0.3259938   0.022406221   0.20644299   0.5571316
x_prob[7]   0.46897241 0.033495749 0.3356998   0.015722103   0.14365877   0.4235867
x_prob[8]   0.47378394 0.030727943 0.3240059   0.013129741   0.16318500   0.4420749
z_prob[1]   0.52560515 0.031547300 0.3286305   0.018388960   0.20320846   0.5550213
z_prob[2]   0.47342530 0.032388415 0.3323511   0.010483992   0.15403045   0.4412354
z_prob[3]   0.48138335 0.024288168 0.2972166   0.021458533   0.21662389   0.4679963
z_prob[4]   0.52477091 0.030564488 0.3238213   0.021646336   0.20497649   0.5546996
z_prob[5]   0.53043538 0.037612554 0.3594749   0.009185296   0.14711397   0.5980563
z_prob[6]   0.46748793 0.036646445 0.3567124   0.008275019   0.11662093   0.4036751
z_prob[7]   0.46928308 0.031075766 0.3271401   0.012422863   0.15836508   0.4300338
z_prob[8]   0.53520960 0.036506380 0.3559883   0.011017414   0.15899693   0.6020018
alpha       0.84288128 0.004903208 0.1328105   0.495493896   0.77739007   0.8791112
beta        0.08643854 0.003764911 0.1007688   0.001575219   0.01844138   0.0486498
lp__      -76.06556122 0.129462953 3.7709042 -84.645570937 -78.28974038 -75.5811353
                  75%       97.5%     n_eff     Rhat
x_prob[1]   0.8788707   0.9907077  92.74258 1.059581
x_prob[2]   0.8841432   0.9894480  92.14802 1.062257
x_prob[3]   0.7910437   0.9855612 114.50130 1.044380
x_prob[4]   0.7592288   0.9719845 140.43037 1.040285
x_prob[5]   0.8874325   0.9919339  90.27383 1.065438
x_prob[6]   0.8409497   0.9847431 108.43353 1.052919
x_prob[7]   0.8063384   0.9830522 100.44369 1.057658
x_prob[8]   0.7888879   0.9829815 111.18311 1.049953
z_prob[1]   0.8405482   0.9845766 108.51542 1.050264
z_prob[2]   0.8033599   0.9829672 105.29670 1.055173
z_prob[3]   0.7437687   0.9731136 149.74652 1.035735
z_prob[4]   0.8337083   0.9829047 112.24744 1.046241
z_prob[5]   0.8862753   0.9926954  91.34220 1.062050
z_prob[6]   0.8405933   0.9903413  94.74860 1.060763
z_prob[7]   0.7935304   0.9819733 110.82146 1.054304
z_prob[8]   0.8860739   0.9926517  95.08980 1.058714
alpha       0.9451407   0.9944410 733.67550 1.002910
beta        0.1123582   0.3918710 716.37944 1.002677
lp__      -73.4013764 -69.9297720 848.39868 1.002594

これを見る限りRHatは問題なさそうである。
ただし、n_effが小さく見直しの余地あり。

事後分布とサンプルのトレースを確認すればモデルが良くないことが分かる。

事後分布
stan_dens(stan.fit,pars ="x_prob[1]",separate_chains = TRUE)

f:id:kento1109:20180704212745p:plain
ここから分かる問題は

  • 明らかに分布が多峰性となっている
  • chainによる分布のバラつきが大きい。

ということである。
ちなみに理想的な事後分布はこんな感じ。
f:id:kento1109:20180704212849p:plain

トレース結果
stan_trace(stan.fit,pars = "x_prob[1]")

f:id:kento1109:20180704213212p:plain
理想としてはサンプリングがどこかの値付近でバラついてほしい。
しかし、取り得る範囲0~1の間を広く動いている。
(その結果、平均が0.5あたりになるのであろう。)

どう考えてもうまくいっていない・・

なぜこのような問題が起きたのか。
原因はおそらく次の2つだと考えている。

  • データ数の不足
  • ラベルスイッチング

次回、この問題に対する対処方法を考えたい。