Stan:階層モデル①
はじめに
今回は「アヒル本」第8章について整理する。
階層モデルは簡単に説明すると、「説明変数だけでは説明がつかない、グループ(個人)に由来する差異」を扱うための手法。
推定に使えるデータは大量にあるが、グループや個人当たりのデータは少ない場合も多く、それらは階層モデルを用いないとうまく推定できないことがある。
今回の例は「年齢(説明変数)と年収(目的変数)の回帰式の推定」を行う。
1.グループ差を考慮しない
まず、グループ差を考慮しない場合のモデル式を考える。
すなわち、全てのデータを使用した回帰式の推定である。
この場合、全データによる推定結果が得られる。
個人の属性(会社、性別、職種)等が不明の場合、これ以上の推定は出来ない。
ただし、全データをグループ化できるような属性がある場合、そのグループによる差異を考慮してモデリングを検討するべきである。
今回は「グループ差(会社差)」を考慮に入れてモデルを考え直す。
2.グループごとの回帰式
次にグループごとの回帰式を考えてみる。
モデル式としては以下のとおりである。
この場合、下記のような問題が生じる。
- データが少ないグループで過学習が起きる
- 未知のグループの予測が出来ない
3.階層モデル
最後に全てのデータを使って、グループ毎の差もモデルに考慮する。また、グループのデータが少ない場合うまく推定できない問題にも対応する。
各会社の回帰式は「全ての会社の平均+会社の差」に基づくと考える。
モデル式は以下のとおりである。
を推定する場合、「グループの回帰式で推定する」という点は2.と変わらない。しかし、階層モデルの場合、グループの回帰式は、「全体の推定値を平均とした正規分布」に基づいて生成されると考える。
Stanで書いてみる
※アヒル本の「model8-4.stan」を引用
data { int N; int K; real X[N]; real Y[N]; int<lower=1, upper=K> KID[N]; } parameters { real a0; real b0; real a[K]; real b[K]; real<lower=0> s_a; real<lower=0> s_b; real<lower=0> s_Y; } model { for (k in 1:K) { a[k] ~ normal(a0, s_a); b[k] ~ normal(b0, s_b); } for (n in 1:N) Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y); }
以下のようにしてRから実行する。
d <- read.csv("input/data-salary-2.txt") N <- nrow(d) K <- 4 data <- list(N=N, X=d$X, Y=d$Y, KID=d$KID) fit <- stan(file="model/model8-4.stan", data=data, seed=1234)
推定結果は以下の通り。
> fit Inference for Stan model: model8-4. 4 chains, each with iter=2000; warmup=1000; thin=1; post-warmup draws per chain=1000, total post-warmup draws=4000. mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat a0 397.61 10.77 184.85 181.11 338.71 368.71 417.27 765.48 295 1.01 b0 11.76 0.34 8.58 -3.81 9.09 12.24 15.15 25.44 630 1.01 a[1] 383.78 0.34 14.71 353.84 374.23 383.96 393.89 412.12 1822 1.00 a[2] 334.86 0.54 16.79 302.27 323.48 334.70 345.65 369.15 958 1.00 a[3] 325.02 0.88 32.74 260.15 303.74 325.79 347.78 385.84 1369 1.00 a[4] 489.30 4.69 135.62 308.32 384.63 457.54 568.75 812.74 837 1.00 b[1] 7.72 0.02 0.93 5.98 7.08 7.68 8.33 9.63 1795 1.00 b[2] 19.41 0.04 1.25 16.92 18.57 19.43 20.26 21.78 1054 1.00 b[3] 11.93 0.04 1.61 8.89 10.83 11.91 12.99 15.11 1391 1.00 b[4] 9.48 0.19 5.46 -3.43 6.32 10.66 13.62 16.97 839 1.00 s_a 180.47 14.32 275.11 14.66 52.12 103.31 199.37 793.77 369 1.01 s_b 11.68 0.46 12.45 3.27 5.72 8.14 12.86 43.31 737 1.01 s_Y 28.43 0.09 3.83 22.07 25.73 28.07 30.68 36.93 1918 1.00 lp__ -173.06 0.13 3.35 -180.71 -175.00 -172.71 -170.67 -167.52 676 1.00