機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Stan:階層モデル①

はじめに


今回は「アヒル本」第8章について整理する。
階層モデルは簡単に説明すると、「説明変数だけでは説明がつかない、グループ(個人)に由来する差異」を扱うための手法。

推定に使えるデータは大量にあるが、グループや個人当たりのデータは少ない場合も多く、それらは階層モデルを用いないとうまく推定できないことがある。

今回の例は「年齢x(説明変数)と年収y(目的変数)の回帰式の推定」を行う。

1.グループ差を考慮しない

まず、グループ差を考慮しない場合のモデル式を考える。
すなわち、全てのデータを使用した回帰式の推定である。

\begin{eqnarray}
y [n]  &\sim & {\rm Normal}(a+b x[n],\sigma)
\end{eqnarray}

この場合、全データによる推定結果が得られる。
個人の属性(会社、性別、職種)等が不明の場合、これ以上の推定は出来ない。
ただし、全データをグループ化できるような属性がある場合、そのグループによる差異を考慮してモデリングを検討するべきである。
今回は「グループ差(会社差)」を考慮に入れてモデルを考え直す。

2.グループごとの回帰式

次にグループごとの回帰式を考えてみる。
モデル式としては以下のとおりである。

y [n]  \sim  {\rm Normal}(a[KID[n]]+b[KID[n]] x[n],\sigma)
この場合、下記のような問題が生じる。

  • データが少ないグループで過学習が起きる
  • 未知のグループの予測が出来ない
3.階層モデル

最後に全てのデータを使って、グループ毎の差もモデルに考慮する。また、グループのデータが少ない場合うまく推定できない問題にも対応する。
各会社の回帰式は「全ての会社の平均+会社の差」に基づくと考える。
モデル式は以下のとおりである。

\begin{eqnarray}
y [n]  &\sim & {\rm Normal}(a[KID[n]]+b[KID[n]] x[n],\sigma)\\
a[k] &=&  {\rm Normal}(a_{all}, \sigma)\\
b[k] &=&  {\rm Normal}(b_{all}, \sigma)\\
\end{eqnarray}

y を推定する場合、「グループの回帰式で推定する」という点は2.と変わらない。しかし、階層モデルの場合、グループの回帰式は、「全体の推定値を平均とした正規分布」に基づいて生成されると考える。

Stanで書いてみる

※アヒル本の「model8-4.stan」を引用

data {
  int N;
  int K;
  real X[N];
  real Y[N];
  int<lower=1, upper=K> KID[N];
}

parameters {
  real a0;
  real b0;
  real a[K];
  real b[K];
  real<lower=0> s_a;
  real<lower=0> s_b;
  real<lower=0> s_Y;
}

model {
  for (k in 1:K) {
    a[k] ~ normal(a0, s_a);
    b[k] ~ normal(b0, s_b);
  }

  for (n in 1:N)
    Y[n] ~ normal(a[KID[n]] + b[KID[n]]*X[n], s_Y);
}

以下のようにしてRから実行する。

d <- read.csv("input/data-salary-2.txt")
N <- nrow(d)
K <- 4
data <- list(N=N, X=d$X, Y=d$Y, KID=d$KID)
fit <- stan(file="model/model8-4.stan", data=data, seed=1234)

推定結果は以下の通り。

> fit
Inference for Stan model: model8-4.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

        mean se_mean     sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
a0    397.61   10.77 184.85  181.11  338.71  368.71  417.27  765.48   295 1.01
b0     11.76    0.34   8.58   -3.81    9.09   12.24   15.15   25.44   630 1.01
a[1]  383.78    0.34  14.71  353.84  374.23  383.96  393.89  412.12  1822 1.00
a[2]  334.86    0.54  16.79  302.27  323.48  334.70  345.65  369.15   958 1.00
a[3]  325.02    0.88  32.74  260.15  303.74  325.79  347.78  385.84  1369 1.00
a[4]  489.30    4.69 135.62  308.32  384.63  457.54  568.75  812.74   837 1.00
b[1]    7.72    0.02   0.93    5.98    7.08    7.68    8.33    9.63  1795 1.00
b[2]   19.41    0.04   1.25   16.92   18.57   19.43   20.26   21.78  1054 1.00
b[3]   11.93    0.04   1.61    8.89   10.83   11.91   12.99   15.11  1391 1.00
b[4]    9.48    0.19   5.46   -3.43    6.32   10.66   13.62   16.97   839 1.00
s_a   180.47   14.32 275.11   14.66   52.12  103.31  199.37  793.77   369 1.01
s_b    11.68    0.46  12.45    3.27    5.72    8.14   12.86   43.31   737 1.01
s_Y    28.43    0.09   3.83   22.07   25.73   28.07   30.68   36.93  1918 1.00
lp__ -173.06    0.13   3.35 -180.71 -175.00 -172.71 -170.67 -167.52   676 1.00