機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Stan:TwoCountryQuiz②

はじめに



前回は「コワイ本」のTwoCountryQuizのタスクに取り組んだ。

今回はこの続きについてまとめる。

ラベルスイッチング



前回のモデルは事後分布が二峰性をもつような分布となっていた。
この原因として「ラベルスイッチング」が原因だと述べた。
混合モデルを考える場合、これはやはり避けて通れない。

はじめに「何故そう考えたか」について尤度計算を交えて説明する。
まず、今回のデータ(TwoCountryQuiz)は以下のような行列だった。

A B ... H
人1 1 0 ... 1
人2 1 0 ... 1
... ... ... ... ...
人8 0 1 ... 0

そして、

  • 人1が「タイ人かどうか」
  • 問題Aが「タイの歴史に関するものかどうか」

を推論することが今回の目的だった。

これは以下の推論でも構わない。

そして、当たり前だがモデルはユーザーがどっちを想定しているかなど知らない。

この推論が同じ結果になることを尤度計算により確認する。

全体尤度の計算は以下の通り。

\begin{eqnarray}
L &=& \prod_{i=1}^{nx}\prod_{j=1}^{zj}\left(
    \begin{array}{ccc}
      (1-p(x_i))\times(1-p(z_j)) \times {\rm Bern}(k_{i,j}|\alpha)\\
      + \\
      p(x_i)\times p(z_j) \times {\rm Bern}(k_{i,j}|\alpha)\\
      + \\
(1-p(x_i))\times p(z_j) \times {\rm Bern}(k_{i,j}|\beta)\\
      + \\
p(x_i)\times  (1-p(z_j)) \times  {\rm Bern}(k_{i,j}|\beta)\\
    \end{array}
  \right)
\end{eqnarray}

対数を取ると以下の通り。


\begin{eqnarray}
\log L &=& \sum_{i=1}^{nx}\sum_{j=1}^{zj} \log\left( 
    \begin{array}{ccc}
      (1-p(x_i))\times (1-p(z_j)) \times  {\rm Bern}(k_{i,j}|\alpha)\\
      + \\
      p(x_i)\times p(z_j) \times  {\rm Bern}(k_{i,j}|\alpha)\\
      + \\
(1-p(x_i))\times p(z_j) \times {\rm Bern}(k_{i,j}|\beta)\\
      + \\
p(x_i)\times  (1-p(z_j)) \times  {\rm Bern}(k_{i,j}|\beta)\\
    \end{array}
  \right)
\end{eqnarray}

log_sum_expを使うと以下の通り。


\begin{eqnarray}
\log L &=& \sum_{i=1}^{nx}\sum_{j=1}^{zj} {\rm log\_sum\_exp}\left( 
    \begin{array}{ccc}
      \log(1-p(x_i))+\log(1-p(z_j)) + \log{\rm Bern}(k_{i,j}|\alpha)\\
      \log p(x_i)+\log p(z_j) +\log {\rm Bern}(k_{i,j}|\alpha)\\
\log(1-p(x_i))+\log p(z_j) +\log{\rm Bern}(k_{i,j}|\beta)\\
\log p(x_i)+ \log(1-p(z_j)) +\log {\rm Bern}(k_{i,j}|\beta)\\
    \end{array}
  \right)
\end{eqnarray}

※各式の意味に関しては前回参照

実際に人1の問題Aに関する尤度を計算する。
尚、計算時点でのパラメータは以下で与えられたとする。

  • \alpha=0.8
  • \beta=0.1
  • x_1=0.2
  • z_1=0.3

計算結果は以下の通り。


\begin{eqnarray}
\log L_{1,1} &=& {\rm log\_sum\_exp}\left( 
    \begin{array}{ccc}
     \log (0.8)+\log (0.7) + \log{\rm Bern}(1|0.8)\\
      \log (0.2)+\log (0.3) + \log{\rm Bern}(1|0.8)\\
\log (0.8)+\log (0.3) + \log{\rm Bern}(1|0.1)\\
\log (0.2)+\log (0.7) + \log{\rm Bern}(1|0.1)\\
    \end{array}
  \right)\\
&=& {\rm log\_sum\_exp}\left(
\begin{array}{ccc}-0.8\\-3.03\\-3.72\\-4.27
\end{array}
\right)\\
&=&-0.62
\end{eqnarray}


他のパラメータを固定したもとでz_1=0.7に更新すると尤度はどう変化するか確認する。
計算結果は以下の通り。


\begin{eqnarray}
\log L_{1,1} &=& {\rm log\_sum\_exp}\left( 
    \begin{array}{ccc}
     \log (0.8)+\log (0.3) + \log{\rm Bern}(1|0.8)\\
      \log (0.2)+\log (0.7) + \log{\rm Bern}(1|0.8)\\
\log (0.8)+\log (0.7) + \log{\rm Bern}(1|0.1)\\
\log (0.2)+\log (0.3) + \log{\rm Bern}(1|0.1)\\
    \end{array}
  \right)\\  &=&{\rm log\_sum\_exp}\left(
\begin{array}{ccc}-1.64\\-2.19\\-2.88\\-5.11
\end{array}
\right)\\
&=&-0.99
\end{eqnarray}

尤度が小さくなることが分かる。
これは、「各人がタイ人(モルドバ人)であるとき、問題がモルドバ(タイ)の歴史」としたの場合の尤度に近い(確率なので同じではない)ので尤度が小さくなる。

では、この条件下でx_1=0.8に更新すると尤度はどう変化するか。


\begin{eqnarray}
\log L_{1,1} &=& {\rm log\_sum\_exp}\left( 
    \begin{array}{ccc}
     \log (0.2)+\log (0.3) + \log{\rm Bern}(1|0.8)\\
      \log (0.8)+\log (0.7) + \log{\rm Bern}(1|0.8)\\
\log (0.2)+\log (0.7) + \log{\rm Bern}(1|0.1)\\
\log (0.8)+\log (0.3) + \log{\rm Bern}(1|0.1)\\
    \end{array}
  \right)\\ &=&  {\rm log\_sum\_exp}\left(
\begin{array}{ccc}-0.8\\-3.03\\-4.27\\-3.72
\end{array}
\right)\\
&=&-0.62
\end{eqnarray}

最初のパラメータでの尤度と等しくなる。
つまり、\alpha,\betaに関係なく、x_1=0.2,z_1=0.3x_1=0.8,z_1=0.7での尤度は等しくなるというわけである。
そんな訳で尤度を大きくするパラメータは、x_1=0.2でもx_1=0.8でも[z_1]によっては尤度が同じということになり、事後分布が二峰性を持つことになるのである。

パラメータの制約

では、どうすればよいか。
問題はx_1=0.2でもx_1=0.8尤度が同じになるということだった。では、x_1が取れる範囲を制約してしまおうというのが1つの対処法である。
例えば、x_1>0.5の制約を設けることで、その制約下で尤度を大きくするためのz_1を探索する。
実装方法としては以下のような書き方が考えられる。

parameters {
  real<lower=0.5, upper=1> init_x_prob;
  real<lower=0, upper=1> other_x_prob[nx-1];
}
transformed parameters {
  real x_prob[nx];
  x_prob[1] = init_x_prob;
  for(i in 1:nx-1)
    x_prob[i+1] = other_x_prob[i];
}

これでx_1の乱数生成を0.5\sim1.0に制限する。
※ベクトルの1要素のみに制約を設定する方法は既にグーグルフォーラムで挙がっていたので参考とさせて頂いた。
Google グループ

モデリング



全体のStanは以下の通り。
x_1の制約以外はほとんど前と同じ。

data{
  int<lower=1> nx;              // num person
  int<lower=1> nz;              // num question
  int<lower=0, upper=1> k[nx, nz];   // person * question
}
parameters {
  real<lower=0.5, upper=1> init_x_prob;
  real<lower=0, upper=1> other_x_prob[nx-1];
  real<lower=0, upper=1> z_prob[nz];
  real<lower=0, upper=1> alpha;
  real<lower=0, upper=alpha> beta;
}
transformed parameters {
  real x_prob[nx];
  x_prob[1] = init_x_prob;
  for(i in 1:nx-1)
    x_prob[i+1] = other_x_prob[i];
}
model {
  vector[4] lp;
  //for(i in 1:nx)
  //  x_prob[i] ~ beta(a,b);;  // prior
  //for(j in 1:nz)
  //  z_prob[j] ~ beta(a,b);;  // prior
  for(i in 1:nx){
    for(j in 1:nz){
      lp[1] = (log1m(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha));
      lp[2] = (log(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| alpha));
      lp[3] = (log1m(x_prob[i]) + log(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta));
      lp[4] = (log(x_prob[i]) + log1m(z_prob[j]) + bernoulli_lpmf(k[i, j]| beta));
      target += log_sum_exp(lp);
    }
  }
}

結果の確認



実行結果は以下の通り。

stan.fit <- stan("TwoCountryQuiz.stan",data=stan.dat,pars=c('x_prob','z_prob'))
summary(stan.fit)
$summary
                 mean     se_mean        sd          2.5%          25%         50%
x_prob[1]   0.8257743 0.004583606 0.1372088   0.528647083   0.73226802   0.8620697
x_prob[2]   0.7505101 0.013206512 0.2454967   0.077761085   0.65885906   0.8371625
x_prob[3]   0.2848253 0.011180096 0.2422784   0.008530586   0.09599669   0.2132672
x_prob[4]   0.3344115 0.008289728 0.2467132   0.017356223   0.13546493   0.2774034
x_prob[5]   0.7564982 0.013374523 0.2441466   0.090157994   0.66361202   0.8418284
x_prob[6]   0.7106300 0.011092027 0.2467630   0.087634316   0.57762401   0.7808384
x_prob[7]   0.2776645 0.010963704 0.2489371   0.007548964   0.08553814   0.1958664
x_prob[8]   0.2905203 0.010156586 0.2449326   0.007862108   0.09827746   0.2208900
z_prob[1]   0.7189481 0.010696188 0.2444943   0.085165215   0.60045169   0.7911544
z_prob[2]   0.2712665 0.010862294 0.2398265   0.009096772   0.08682635   0.1975436
z_prob[3]   0.3250752 0.007364386 0.2474678   0.015272428   0.12510105   0.2692455
z_prob[4]   0.7139130 0.011395878 0.2432047   0.099609097   0.59076957   0.7835717
z_prob[5]   0.7556119 0.012485880 0.2441034   0.091712754   0.65056501   0.8450558
z_prob[6]   0.2471072 0.012882019 0.2460234   0.004999654   0.06493085   0.1560054
z_prob[7]   0.2848787 0.012258488 0.2448348   0.010895594   0.09539916   0.2123684
z_prob[8]   0.7588505 0.012929736 0.2370986   0.105125385   0.66719031   0.8418040
lp__      -76.6538982 0.173507422 4.0323283 -85.464497249 -79.14733785 -76.2325885
                  75%       97.5%     n_eff     Rhat
x_prob[1]   0.9405903   0.9950119  896.0852 1.002287
x_prob[2]   0.9309936   0.9919926  345.5532 1.005609
x_prob[3]   0.4083455   0.9058809  469.6110 1.002421
x_prob[4]   0.4877216   0.9151902  885.7355 1.002222
x_prob[5]   0.9374089   0.9953916  333.2304 1.006969
x_prob[6]   0.9068108   0.9915599  494.9237 1.002667
x_prob[7]   0.4008561   0.9103599  515.5425 1.004395
x_prob[8]   0.4202415   0.9004982  581.5642 1.004779
z_prob[1]   0.9100841   0.9919731  522.4915 1.004377
z_prob[2]   0.3818967   0.9177776  487.4739 1.005363
z_prob[3]   0.4726530   0.9227644 1129.1828 1.001313
z_prob[4]   0.9048383   0.9913114  455.4579 1.002746
z_prob[5]   0.9364634   0.9945814  382.2165 1.003059
z_prob[6]   0.3527158   0.9079530  364.7412 1.005107
z_prob[7]   0.4027377   0.9158467  398.9079 1.005976
z_prob[8]   0.9324511   0.9944236  336.2627 1.007949
lp__      -73.6854277 -70.1485710  540.1018 1.002221

収束判定は問題なさそう。
確率も前回のような0.5周辺ではなく、きっちり分かれている。

事後分布

前回、二峰性だった事後分布を確認する。

stan_dens(stan.fit,pars ="x_prob[1]",separate_chains = TRUE)

f:id:kento1109:20180705232218p:plain
単峰性の分布であることが確認できる。chainによるバラツキもあまりない。裾野が長いのが気になるが、これはサンプル数の問題と思われる。(後で検証してみる。)
一応、制約を加えなかったパラメータも確認する。

stan_dens(stan.fit,pars ="x_prob[3]",separate_chains = TRUE)

f:id:kento1109:20180705232758p:plain
二峰性は見られないことが分かる。

トレース結果
stan_trace(stan.fit,pars = "x_prob[1]")

f:id:kento1109:20180705232526p:plain
制約を加えたこともあり、制約範囲内で動いている。(非常に分散が大きいのが少し気にがるが・・)
同様に制約を加えなかったパラメータも確認する。

stan_trace(stan.fit,pars = "x_prob[3]")

f:id:kento1109:20180705232916p:plain
分散は大きいが0\sim0.3あたりで多くサンプリングされていることが分かる。

サンプル数を増やしてモデリング



最後にサンプル数を24人に増やして再度モデリングを行った。
(データセットは単純に8人分のデータ×3しただけ)

結果は以下の通り。

summary(stan.fit)
$summary
                    mean      se_mean         sd          2.5%           25%
x_prob[1]     0.88909723 0.0015264536 0.09654140  6.379189e-01    0.84088896
x_prob[2]     0.88306846 0.0017073929 0.10798501  5.979277e-01    0.83358725
x_prob[3]     0.17779696 0.0022253553 0.14074382  6.284815e-03    0.06514292
x_prob[4]     0.24431969 0.0026401563 0.16697814  1.462065e-02    0.11160276
x_prob[5]     0.88339388 0.0016822210 0.10639300  5.915050e-01    0.83695690
x_prob[6]     0.82474879 0.0021329165 0.13489748  4.874794e-01    0.74488734
x_prob[7]     0.15902617 0.0020105498 0.12715834  5.048530e-03    0.05834003
x_prob[8]     0.17922045 0.0021341862 0.13497779  7.067403e-03    0.07218645
x_prob[9]     0.88388361 0.0016258287 0.10282643  6.155001e-01    0.83528646
x_prob[10]    0.88482958 0.0016178430 0.10232138  6.219439e-01    0.83460776
x_prob[11]    0.17376721 0.0021239378 0.13432962  7.741369e-03    0.07019395
x_prob[12]    0.24658736 0.0025807395 0.16322030  1.360339e-02    0.11582604
x_prob[13]    0.88225775 0.0016771932 0.10607501  6.072548e-01    0.82842426
x_prob[14]    0.82616041 0.0021704359 0.13727042  4.905692e-01    0.74875851
x_prob[15]    0.15963763 0.0021007873 0.13286546  4.750483e-03    0.05315915
x_prob[16]    0.18142896 0.0021570449 0.13642350  6.887512e-03    0.07567072
x_prob[17]    0.88654126 0.0016228331 0.10263698  6.202509e-01    0.83865202
x_prob[18]    0.88467146 0.0016392687 0.10367646  6.029201e-01    0.83315147
x_prob[19]    0.17220191 0.0021321826 0.13485107  6.665817e-03    0.06455919
x_prob[20]    0.24572632 0.0026661863 0.16862443  1.241102e-02    0.11079253
x_prob[21]    0.88259919 0.0016423417 0.10387081  6.201909e-01    0.83208995
x_prob[22]    0.82613928 0.0022048355 0.13944604  4.796073e-01    0.75000947
x_prob[23]    0.15867396 0.0020519952 0.12977957  5.023390e-03    0.05646634
x_prob[24]    0.18281938 0.0021713075 0.13732554  9.281820e-03    0.07586258
z_prob[1]     0.90059791 0.0013291991 0.08406593  6.808949e-01    0.85748083
z_prob[2]     0.08281926 0.0011276441 0.07131847  2.591367e-03    0.02921736
z_prob[3]     0.16589135 0.0018243489 0.11538195  6.090128e-03    0.07509763
z_prob[4]     0.89537937 0.0013233079 0.08369334  6.885263e-01    0.85128247
z_prob[5]     0.94851652 0.0007727115 0.04887056  8.209348e-01    0.92862688
z_prob[6]     0.05198457 0.0007983830 0.05049418  1.322733e-03    0.01514927
z_prob[7]     0.09923228 0.0012919329 0.08170901  3.422783e-03    0.03655729
z_prob[8]     0.94831005 0.0008012587 0.05067605  8.110328e-01    0.92770249
lp__       -165.12638755 0.1230833587 4.69380930 -1.754599e+02 -167.91580233
                     50%          75%        97.5%    n_eff      Rhat
x_prob[1]     0.91570868    0.9632312    0.9965173 4000.000 1.0000139
x_prob[2]     0.91501176    0.9641354    0.9965150 4000.000 0.9997085
x_prob[3]     0.14653449    0.2576420    0.5169342 4000.000 0.9991991
x_prob[4]     0.21428473    0.3508085    0.6205980 4000.000 0.9994822
x_prob[5]     0.91291980    0.9619288    0.9971203 4000.000 0.9994962
x_prob[6]     0.85376265    0.9305850    0.9914166 4000.000 0.9991922
x_prob[7]     0.12972880    0.2282888    0.4739765 4000.000 0.9999588
x_prob[8]     0.15058849    0.2521316    0.5022369 4000.000 0.9998293
x_prob[9]     0.91202499    0.9622890    0.9963306 4000.000 0.9999095
x_prob[10]    0.91357950    0.9628045    0.9967515 4000.000 0.9998023
x_prob[11]    0.14187091    0.2502796    0.4986213 4000.000 0.9994550
x_prob[12]    0.22103501    0.3514679    0.6198173 4000.000 1.0002860
x_prob[13]    0.91473044    0.9630547    0.9967746 4000.000 0.9994394
x_prob[14]    0.85865190    0.9357102    0.9948808 4000.000 0.9998928
x_prob[15]    0.12518733    0.2320583    0.4899430 4000.000 0.9995863
x_prob[16]    0.15374736    0.2578792    0.5127041 4000.000 0.9997273
x_prob[17]    0.91694762    0.9645227    0.9965024 4000.000 0.9995645
x_prob[18]    0.91524153    0.9621940    0.9968299 4000.000 0.9993828
x_prob[19]    0.14001510    0.2524526    0.4938968 4000.000 0.9995974
x_prob[20]    0.21958370    0.3541058    0.6335430 4000.000 0.9998727
x_prob[21]    0.91241451    0.9615494    0.9964914 4000.000 0.9999107
x_prob[22]    0.85743842    0.9350120    0.9944499 4000.000 0.9998915
x_prob[23]    0.12682041    0.2299418    0.4853042 4000.000 1.0002858
x_prob[24]    0.15102426    0.2611727    0.5188826 4000.000 0.9996171
z_prob[1]     0.92263662    0.9645727    0.9961818 4000.000 0.9993239
z_prob[2]     0.06412635    0.1186724    0.2634773 4000.000 0.9995061
z_prob[3]     0.14696709    0.2362540    0.4324484 4000.000 0.9992923
z_prob[4]     0.91543158    0.9591549    0.9955493 4000.000 0.9997276
z_prob[5]     0.96217694    0.9843390    0.9986652 4000.000 0.9999034
z_prob[6]     0.03690516    0.0717256    0.1881421 4000.000 1.0008297
z_prob[7]     0.07866694    0.1412252    0.2983229 4000.000 0.9998847
z_prob[8]     0.96319514    0.9854857    0.9990697 4000.000 0.9995410
lp__       -164.78851211 -161.7427068 -157.0022084 1454.294 1.0035093

全体的に分散が小さくなっていることが分かる。

事後分布

サンプル数を増やしたことで裾野が短くなっていることが分かる。
f:id:kento1109:20180705233724p:plain

トレース結果

トレース結果からも分散が小さくなっていることが分かる。
f:id:kento1109:20180705234112p:plain

さいごに



TwoCountryQuizを取り組んだが、実力不足もあり時間がかかってしまった。しかし、ラベルスイッチングやパラメータの制約に関して理解が深まったので良かった。
次はTwoCountryQuizを応用した別のタスクをを検討したい。