機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Stan:TwentyQuestions

はじめに



最近、色々なモデリング例を知りたくて「ベイズ統計で実践モデリング(通称コワイ本)」を呼んだ。
まだ途中だが、面白そうな内容があったのでまとめる。
今回、紹介するのは「第六章 潜在混合モデル」の「TwentyQuestions(二十の問題)」というスタディ。

TwentyQuestions


問題の設定としては以下の通り(本文より引用)

10名の人からなるグループがある講義に出席し、そのあとで20の問題に答えたとしよう。どの回答も正答もしくは誤答のいずれかである。正答と誤答のパターンから、2つのことを推論したい。第一は、各人がどれだけしっかりこの講義に注意を向けていたかである。第二は、それぞれの問題がどのくらい難しかったかである。

では実際に取り組む。

データの準備



まずは、データを読み込む。

TwentyQuesitons <- read.csv("TwentyQuestions.txt")
head(TwentyQuesitons)
  A B C D E F G H I J K L M N O P Q R S T
1 1 1 1 1 0 0 1 1 0 1 0 0 1 0 0 1 0 1 0 0
2 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
3 0 0 1 0 0 0 1 1 0 0 0 0 1 0 0 0 0 0 0 0
4 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0 0 0 0
5 1 0 1 1 0 1 1 1 0 1 0 0 1 0 0 0 0 1 0 0
6 1 1 0 1 0 0 0 1 0 1 0 1 1 0 0 1 0 1 0 0

大したものではないが、ファイルは以下に置いた。
github.com

初めにベイズモデリングによらない場合を考える。
2つの問いに答えるための最もシンプルな回答は各人の正答率と各問題の正答率を計算すれば求まる。
すなわち以下のとおりである。

apply(TwentyQuesitons,1,sum)
 [1] 10  2  4  3  9  9  1  0  7  4
apply(TwentyQuesitons,2,sum)
A B C D E F G H I J K L M N O P Q R S T 
4 4 5 3 0 1 5 5 1 6 0 2 6 0 0 2 0 4 0 1

この結果は後のモデリング推定結果と比較する。

モデリング



今回の推論方法は、「各人は講義のある部分をしっかりと聴講し、各問題はその人が講義で大事なポイントを聞いていた場合にある確率で正しく答えられる」と仮定する。
グラフィカルモデルは以下の通り。

f:id:kento1109:20180703164239p:plain
このモデルのもとで、i番目の人が聴講している確率をp_i、関連する情報を聴いていた場合にj番目の問題に正しく答えれれる確率をq_jとすると、i番目の人がj番目の問題に正しく答える確率は\theta_{ij}=p_iq_jである。i番目の人がj番目の問題に正しく答えられた場合には、k_{ij}=1、そうでない場合にはk_{ij}=0とすると、観測された正答と誤答のパターンは確率\theta_{ij}のベルヌーイ分布から取り出せる。

これをStanで書くと以下のようになる。

data{
  int<lower=1> np;              // num person
  int<lower=1> nq;              // num question
  int<lower=0, upper=1> k[np, nq];   // person * question
  real<lower=0> alpha;          // hyper parameter
  real<lower=0> beta;           // hyper parameter
}
parameters {
  vector<lower=0, upper=1>[np] p;
  vector<lower=0, upper=1>[nq] q;
}
model {
  matrix[np, nq] theta;
  for(i in 1:np)
    p[i] ~ beta(alpha,beta);  // prior
  for(j in 1:nq)
    q[j] ~ beta(alpha,beta);  // prior
  for(i in 1:np){
    for(j in 1:nq){
      theta[i,j] = p[i]*q[j];
    }
  }
  for(i in 1:np){
    for(j in 1:nq){
      k[i,j] ~ bernoulli(theta[i,j]);
    }
  }
}

推論結果



以下のようにデータを用意する。

stan.dat <- list(np=nrow(TwentyQuesitons),nq=ncol(TwentyQuesitons),
                 k=TwentyQuesitons,alpha=1,beta=1)

後は実行するだけ。

stan.fit <- stan("TwentyQuesitons.stan",data=stan.dat)

推定結果は問題なさそう。

summary(stan.fit)
$summary
               mean     se_mean        sd          2.5%           25%
p[1]     0.88911487 0.001552650 0.0981982  6.425905e-01    0.84008660
p[2]     0.27364995 0.002184662 0.1381701  6.462616e-02    0.16920010
p[3]     0.47954574 0.002938769 0.1706163  1.822470e-01    0.35275889
p[4]     0.35393080 0.002368980 0.1498274  1.084504e-01    0.24147944
p[5]     0.84774621 0.001795496 0.1135571  5.704474e-01    0.78241336
p[6]     0.82305815 0.001935901 0.1224372  5.427157e-01    0.74716777
p[7]     0.17879313 0.001785442 0.1129212  2.549268e-02    0.09363857
p[8]     0.09235324 0.001389125 0.0878560  2.172566e-03    0.02654940
p[9]     0.72386666 0.002464854 0.1558911  3.958030e-01    0.61389046
p[10]    0.48147527 0.002690157 0.1701404  1.838145e-01    0.36039881
q[1]     0.73233822 0.002854431 0.1805301  3.339069e-01    0.61163525
q[2]     0.69506571 0.002875385 0.1818553  3.140782e-01    0.56818458
q[3]     0.76328023 0.002492521 0.1576409  4.163806e-01    0.66058258
q[4]     0.64671951 0.003265053 0.2065001  2.354464e-01    0.49608217
q[5]     0.15186711 0.002193721 0.1387431  3.881358e-03    0.04565836
q[6]     0.31198580 0.002879697 0.1821280  4.713141e-02    0.17258446
q[7]     0.74930157 0.002537726 0.1604999  3.925689e-01    0.64327315
q[8]     0.82046104 0.002311769 0.1462091  4.662823e-01    0.73670157
q[9]     0.28239688 0.002653954 0.1678508  4.406978e-02    0.15117084
q[10]    0.85396315 0.001925104 0.1217542  5.405742e-01    0.79064639
q[11]    0.15199359 0.002177869 0.1377405  4.351796e-03    0.04631382
q[12]    0.41510680 0.002960891 0.1872632  1.010399e-01    0.27463613
q[13]    0.85773161 0.001918645 0.1213457  5.475036e-01    0.79335583
q[14]    0.15280130 0.002171380 0.1373301  4.628313e-03    0.04955406
q[15]    0.15410531 0.002183213 0.1380785  4.214024e-03    0.05015196
q[16]    0.47620822 0.003237365 0.2047490  1.254105e-01    0.32009231
q[17]    0.15528100 0.002174545 0.1375303  4.917842e-03    0.05054619
q[18]    0.76074940 0.002701077 0.1708311  3.769623e-01    0.64751266
q[19]    0.15321697 0.002222154 0.1405413  4.094431e-03    0.04752177
q[20]    0.30183513 0.002755366 0.1742647  4.846833e-02    0.16575373
lp__  -136.87215270 0.110364382 4.2668932 -1.462405e+02 -139.48909574
                50%          75%        97.5%    n_eff      Rhat
p[1]     0.91612704    0.9652195    0.9970584 4000.000 0.9994427
p[2]     0.25573221    0.3585345    0.5897540 4000.000 0.9998406
p[3]     0.47188776    0.5943082    0.8313272 3370.623 1.0013353
p[4]     0.33890743    0.4515168    0.6780161 4000.000 0.9997905
p[5]     0.87024789    0.9386957    0.9939227 4000.000 0.9995082
p[6]     0.84294127    0.9201649    0.9903020 4000.000 0.9996904
p[7]     0.15821440    0.2421394    0.4480493 4000.000 1.0011406
p[8]     0.06554715    0.1314936    0.3257131 4000.000 0.9994144
p[9]     0.73622064    0.8464901    0.9717190 4000.000 1.0001526
p[10]    0.47382267    0.5942183    0.8325088 4000.000 1.0007546
q[1]     0.75868716    0.8824943    0.9878256 4000.000 0.9999292
q[2]     0.71215751    0.8385588    0.9805242 4000.000 0.9993718
q[3]     0.78525822    0.8908518    0.9878724 4000.000 1.0006274
q[4]     0.66155809    0.8105196    0.9796591 4000.000 1.0014485
q[5]     0.11103109    0.2182946    0.5075945 4000.000 1.0007001
q[6]     0.28213609    0.4225061    0.7420729 4000.000 0.9995974
q[7]     0.76984773    0.8781185    0.9856990 4000.000 0.9997356
q[8]     0.85699492    0.9378988    0.9946304 4000.000 0.9996302
q[9]     0.25311287    0.3871429    0.6589538 4000.000 0.9992549
q[10]    0.88456150    0.9486784    0.9948674 4000.000 0.9995444
q[11]    0.11334117    0.2191997    0.5034222 4000.000 0.9997808
q[12]    0.39800742    0.5413462    0.8099402 4000.000 0.9994681
q[13]    0.89264225    0.9529320    0.9957055 4000.000 0.9993322
q[14]    0.11854880    0.2134161    0.5220120 4000.000 0.9996678
q[15]    0.11334822    0.2203318    0.5086501 4000.000 0.9990846
q[16]    0.46311487    0.6224283    0.8949644 4000.000 1.0009052
q[17]    0.11769525    0.2218003    0.5102989 4000.000 1.0000498
q[18]    0.79117938    0.9031309    0.9894773 4000.000 0.9997974
q[19]    0.10954860    0.2223453    0.5132886 4000.000 0.9998923
q[20]    0.27322087    0.4119618    0.7153972 4000.000 0.9991965
lp__  -136.48603076 -133.8919920 -129.5536950 1494.740 1.0021107

最も講義に注意を向けていたのはp_1の人で、その確率は「0.889」であった。
また、難しい問題は、q_5,q_{11},q_{14},q_{15},q_{19}などであった。
いずれも最初のシンプルな方法と同じ結果となった。

今回のようなシンプルな推論は次の二国クイズの基礎編として取り上げた。次は二国クイズというタスクを勉強する。