機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Snorkelの識別モデルについて(実装編)

はじめに



前回は生成モデルの構築について確認した。
kento1109.hatenablog.com

尚、識別モデルに関する理論的なことをはこっちにまとめた。
(大したこと書いてないが・・)
Snorkelの識別モデルについて(理論編) - 機械学習・自然言語処理の勉強メモ

生成モデルにデータを入れることで、確率値を出力された。チュートリアルの例で言うと、配偶者の関係の候補となる人物名のペアを入力とすることで、その人物名同士が配偶者である確率が出力された。
f:id:kento1109:20180116143903p:plain
https://hazyresearch.github.io/snorkel/pdfs/snorkel_demo.pdfより引用

識別モデルではこの確率値を正解ラベルとして使用し、訓練を行う。
f:id:kento1109:20180110145722p:plain
Weak Supervisionより引用
つまり、訓練データから0~1の連続値を出力するモデルを構築することを意味する。なので、使用するモデルはロジスティック回帰・SVMs・LSTMなど自由。
なぜ、最終的な予測に識別モデルを使うのか。生成モデルの結果ではダメなのか。識別モデルを用いる理由については、

The discriminative model learns a feature representation of our LFs.
This makes it better able to generalize to unseen candidates.

snorkel/Snorkel-Workshop-FINAL.pdf at master · HazyResearch/snorkel · GitHub
と説明されている。
要は「汎化性能を向上させるための表現学習」として必要だということである。

コードを読んでいく



今回のチュートリアルは以下にある。
github.com

Training a LSTM Neural Network
  • train_cands:訓練データ(候補集合)
  • train_marginals:正解ラベル(生成モデルから出力された確率値)
from snorkel.learning.disc_models.rnn import reRNN

train_kwargs = {
    'lr':         0.001,
    'dim':        100,
    'n_epochs':   10,
    'dropout':    0.25,
    'print_freq': 1,
    'batch_size': 128,
    'max_sentence_length': 100
}

lstm = reRNN(seed=1701, n_threads=1)
lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)

今回は真の正解ラベル付きデータ(**_dev)があるので、精度も評価できる。

p, r, f1 = lstm.score(test_cands, L_gold_test)
print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1))
tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)