機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Stan:混合モデルとラベルスイッチング

はじめに



前に混合ポアソンモデルについて勉強したが、混合モデルではラベルスイッチングの問題を考える必要がある。

詳しくはユーザーガイド・リファレンスマニュアル「24.2. 混合分布モデルでのラベルスイッチング」参照。

一番簡単な対策方法は、ポアソン分布の平均などパラメータがスカラー値を取る場合、それらに順序をつけて識別するというもの。
これは前にHMMの記事で書いた。
kento1109.hatenablog.com

スカラー値の場合、これで問題ない。では、「2つのパラメータが合計1の確率パラメータベクトル」の場合どうするのかという疑問が生じた。
このベクトル同士で順序をつけることは出来ない。
今回はそんな時にどうすれば良いかを調べてみた。

データの用意



今回は例として以下のような例を考える。

サイコロを用いたギャンブルを考える。ゲストは一定のコインを支払いギャンブルに参加する。ゲストはサイコロの出た目に応じてコインを得る(出た目が大きいほど、得られるコインを多くなる)。

これは普通のギャンブルの例である。今回は更に以下のような状況を考える。

ホストはあまりコインを払い戻したくないので、ゲストにばれないように時々イカサマしたサイコロを使用する。そのサイコロは六面体が小さい数字が出やすくなっており、これを使用することでホストは払い戻しを小さくすることが可能である。しかし、多用するとばれてしまう恐れがあるので、数回に1回程度でイカサマしたサイコロを使用することとする。

どこかの漫画の地下労働施設で聞いたことがある話だ。
さて、今回はこのようなギャンブルデータを利用してモデリングを行う。隠れ状態はホストが使用しているサイコロ(正常・イカサマ)である。

さて、2つのサイコロを投げた時の試行は以下の多項分布に従うとする。(1ゲーム20回の試行とする。)

  • 正常

x=\rm{Multi}(20,[\frac{1}{6},\frac{1}{6},\frac{1}{6},\frac{1}{6},\frac{1}{6},\frac{1}{6}])

x=\rm{Multi}(20,[\frac{1}{3},\frac{1}{3},\frac{1}{12},\frac{1}{12},\frac{1}{12},\frac{1}{12}])

今回のルールに沿って以下のようにデータを作成した。

data1 <- rmultinom(n=20, size = 20, prob = c(1/6,1/6,1/6,1/6,1/6,1/6))
data2 <- rmultinom(n=5, size = 20, prob = c(1/3,1/3,1/12,1/12,1/12,1/12))
data3 <- rmultinom(n=10, size = 20, prob = c(1/6,1/6,1/6,1/6,1/6,1/6))
data4 <- rmultinom(n=5, size = 20, prob = c(1/3,1/3,1/12,1/12,1/12,1/12))

data <- cbind(data1,data2,data3,data4)

今回、40ゲームを行った結果を用いる。

正常なサイコロを1、イカサマを0とした場合の各試行結果をプロットした。
f:id:kento1109:20180626160256p:plain
今回はこれを推論することを目指す。

モデリング

今回は以下のようなコードを書いた。
基本的には前に書いたポアソン混合モデルと同じである。
kento1109.hatenablog.com

data {
  int N;
  int S;
  int K;
  int x[S,N];
}

parameters {
  simplex[S] theta[K];
  simplex[K] pi;
}

transformed parameters{
  vector[K] lp[N];
  for(n in 1:N){
    for (k in 1:K)
      lp[n,k] = log(pi[k]) + multinomial_lpmf(x[,n] | theta[k]);
  }
}

model {
  for(n in 1:N){
    target += log_sum_exp(lp[n]);
  }
}

generated quantities{
  simplex[K] u[N];
  for(n in 1:N)
    u[n] = softmax(lp[n]);
}

generated quantitiesで各クラスの尤度に応じた分類を出力するようにした。

結果①

以下のようにしてRで実行する。

stan.dat <- list(N=ncol(data),S=nrow(data),K=2,x=data)
stan.fit <- stan("multinomial_mixture.stan", data=stan.dat, chain=1)

最初はラベルスイッチングが起きなかった場合の予測を確認するためchain=1とした。

summary(stan.fit)
$summary
                    mean     se_mean         sd          2.5%           25%
theta[1,1]    0.28863139 0.008274501 0.05288002  1.736715e-01  2.552933e-01
theta[1,2]    0.31642265 0.008847883 0.05686860  1.918218e-01  2.838936e-01
theta[1,3]    0.10897124 0.003457797 0.02918702  5.425333e-02  8.951935e-02
theta[1,4]    0.06170202 0.005684556 0.03422223  9.984189e-03  3.689266e-02
theta[1,5]    0.11466797 0.004083461 0.03255207  6.128509e-02  9.070947e-02
theta[1,6]    0.10960474 0.003745194 0.03116486  4.718494e-02  8.925468e-02
theta[2,1]    0.17255238 0.005818873 0.04387844  7.982645e-02  1.542526e-01
theta[2,2]    0.19023815 0.007462860 0.04732092  9.273994e-02  1.731432e-01
theta[2,3]    0.15445216 0.003098625 0.03386714  9.485220e-02  1.400995e-01
theta[2,4]    0.14000995 0.004377188 0.03306693  4.466819e-02  1.297019e-01
theta[2,5]    0.17686862 0.002823444 0.02848967  1.116273e-01  1.640213e-01
theta[2,6]    0.16587875 0.004179644 0.03448394  9.932948e-02  1.509614e-01
                     50%           75%        97.5%     n_eff      Rhat
theta[1,1]  2.924264e-01  3.236019e-01    0.3865116  40.84131 1.0116637
theta[1,2]  3.185399e-01  3.488754e-01    0.4243274  41.31107 1.0120015
theta[1,3]  1.081742e-01  1.289129e-01    0.1649254  71.24926 1.0040223
theta[1,4]  5.672305e-02  8.089627e-02    0.1403246  36.24295 1.0163097
theta[1,5]  1.119642e-01  1.360349e-01    0.1821790  63.54780 1.0133100
theta[1,6]  1.091764e-01  1.300028e-01    0.1745309  69.24394 1.0029883
theta[2,1]  1.712500e-01  1.877133e-01    0.2790961  56.86231 1.0320961
theta[2,2]  1.910106e-01  2.051154e-01    0.3156735  40.20646 1.0344984
theta[2,3]  1.515950e-01  1.645796e-01    0.2385889 119.45905 1.0025754
theta[2,4]  1.415028e-01  1.544405e-01    0.1898349  57.06863 1.0224458
theta[2,5]  1.771090e-01  1.925877e-01    0.2260736 101.81615 1.0134331
theta[2,6]  1.632656e-01  1.771684e-01    0.2335396  68.06991 1.0327917

RHatも問題なさそうだ。

以下のように各試行の推定結果と真のクラスを比較した。

stan.params <- rstan::extract(stan.fit,pars=c("u"))
u.mean <- apply(stan.params$u, c(2,3), mean)
plot(c(rep(1,20),rep(0,5),rep(1,10),rep(0,5)),ylim=c(0,1), ylab="")
par(new=T)
plot(u.mean[,2],ylim=c(0,1),ylab="",col="red")
abline(h=0.5,col="blue")

f:id:kento1109:20180626162114p:plain
全て正しく推定できているのが分かる。

結果②

次にラベルスイッチングを確認するため、chain=4で実行する。

stan.dat <- list(N=ncol(data),S=nrow(data),K=2,x=data)
stan.fit <- stan("multinomial_mixture.stan", data=stan.dat, chain=4)

結果は以下の通り。

summary(stan.fit)
$summary
                    mean    se_mean         sd          2.5%           25%
theta[1,1]    0.24066245 0.03459264 0.07791532  1.283242e-01  1.728793e-01
theta[1,2]    0.26262961 0.03765012 0.08111626  1.352325e-01  1.922499e-01
theta[1,3]    0.12726122 0.01346508 0.03810829  5.826672e-02  9.983748e-02
theta[1,4]    0.09489373 0.02386259 0.05298067  1.092301e-02  4.750612e-02
theta[1,5]    0.14073448 0.01787237 0.04213303  6.460134e-02  1.059021e-01
theta[1,6]    0.13381852 0.01759404 0.04199431  5.904179e-02  1.016729e-01
theta[2,1]    0.22749524 0.03557697 0.07244701  1.342996e-01  1.686509e-01
theta[2,2]    0.24890954 0.03837881 0.07789662  1.391813e-01  1.869382e-01
theta[2,3]    0.13220354 0.01392479 0.03369018  6.264636e-02  1.084989e-01
theta[2,4]    0.10408959 0.02409088 0.04913367  1.668392e-02  5.707473e-02
theta[2,5]    0.14837353 0.01849434 0.04121531  6.799874e-02  1.131823e-01
theta[2,6]    0.13892856 0.01813406 0.03894839  6.006671e-02  1.092404e-01
pi[1]         0.47329259 0.10596077 0.25259796  8.878489e-02  2.454217e-01
pi[2]         0.52670741 0.10596077 0.25259796  1.156611e-01  2.875244e-01
                     50%          75%        97.5%      n_eff     Rhat
theta[1,1]  2.395716e-01    0.3042625    0.3793181   5.073158 1.482184
theta[1,2]  2.680819e-01    0.3265333    0.4125085   4.641765 1.526439
theta[1,3]  1.307148e-01    0.1521753    0.1902564   8.009789 1.224350
theta[1,4]  9.328236e-02    0.1407524    0.1744657   4.929465 1.524555
theta[1,5]  1.441672e-01    0.1753437    0.2103470   5.557524 1.411984
theta[1,6]  1.368468e-01    0.1635550    0.2038121   5.697047 1.334523
theta[2,1]  1.961380e-01    0.2910343    0.3748417   4.146704 1.675218
theta[2,2]  2.161485e-01    0.3167979    0.4003586   4.119592 1.695924
theta[2,3]  1.379141e-01    0.1554845    0.1892626   5.853685 1.347178
theta[2,4]  1.230448e-01    0.1442089    0.1729441   4.159614 1.693136
theta[2,5]  1.588737e-01    0.1801457    0.2117313   4.966378 1.503050
theta[2,6]  1.467403e-01    0.1672130    0.2064530   4.613060 1.485162

RHatが軒並み大きくなってしまっている。

トレース結果は以下の通り。

 stan_trace(stan.fit,pars = "theta")

f:id:kento1109:20180626163254p:plain
chainによって異なるバラつきになっているのが分かる。

結果①と同様に予測結果をプロットする。
f:id:kento1109:20180626163420p:plain
予測がほとんど0.5あたりになっているのが分かる。
(ラベルスイッチングの問題)
次にラベルスイッチングに対応するための方法について書く。

label.switching



調べてみると、Rパッケージにlabel.switchingというのがあった。
CRAN - Package label.switching
この中のアルゴリズムを利用してラベルスイッチング問題を解決できるそうだ。
参考:https://arxiv.org/pdf/1503.02271.pdf

AIC

赤池情報量基準ではなく、「Artificial Identifiability Constraints」。

関数呼び出しは以下のようにして行う。

library("label.switching")
mcmc.params <- rstan::extract(stan.fit)
z <- matrix(ncol=ncol(data),nrow=4000)
for(i in 1:(4000)){
    z[i,] <- apply(mcmc.params$u[i,,],1,which.max)
}

ls <- label.switching(method = "AIC",
                      z = z,
                      K = 2, 
                      mcmc = mcmc.params$theta)

参考:
習作;ラベルスイッチング問題をなんとかする関数 – Kosugitti's BLOG

これでzの40件のデータのラベルを降り直す。

新しいラベルはls$clustersで確認できる。

ls$clusters
    [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13] [,14]
AIC    1    1    1    1    1    1    1    1    1     1     1     1     1     1
    [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24] [,25] [,26] [,27]
AIC     1     1     1     1     1     1     2     2     2     2     2     1     1
    [,28] [,29] [,30] [,31] [,32] [,33] [,34] [,35] [,36] [,37] [,38] [,39] [,40]
AIC     1     1     1     1     1     1     1     1     2     2     2     2     2

再度、正解ラベルと比べてみる。

plot(c(rep(1,20)-.05,rep(0,5)+.05,rep(1,10)-.05,rep(0,5)+.05),ylim=c(0,1), ylab="")
par(new=T)
y_pred <- ifelse(ls$cluster==2, 0, 1)
plot(t(y_pred),ylim=c(0,1),ylab="",col="red")

f:id:kento1109:20180626175023p:plain
全て正しく一致していることが分かる。
※全て完全一致していたので、微妙にラベルの位置をずらした。

最後に



今回、実験的にAIC関数を使ってみた。
今回は使ってみただけで、label.switching関数の様々なアルゴリズムまで理解できていないので、別の機会に学習しようと思う。

あと、今回はラベルスイッチング問題を解決することに一生懸命になったが、

ある意味, 混合成分のインデックス(あるいはラベル)は重要ではありません. 事後予測推測は, 混合成分が識別できなくとも可能です. 例えば, 新しい観測値の対数確率は, 混合成分が識別されるかに依存しません.
(ユーザーガイド・リファレンスマニュアル「24.2. 混合分布モデルでのラベルスイッチング」)

とある。
要はRHatが大きな値だったり、各Chainが異なった値でバラついていた場合、「ラベルスイッチングが起きてしまっており、このモデルは問題がある。何とかラベルスイッチングを回避しないといけない。」とすぐに考えるのではなく、目的に応じて対応策を考えるのが必要だということだ。
実はこれを意識することが一番大事かもしれない・・