Snorkelの識別モデルについて(実装編)
はじめに
前回は生成モデルの構築について確認した。
kento1109.hatenablog.com
尚、識別モデルに関する理論的なことをはこっちにまとめた。
(大したこと書いてないが・・)
Snorkelの識別モデルについて(理論編) - 機械学習・自然言語処理の勉強メモ
生成モデルにデータを入れることで、確率値を出力された。チュートリアルの例で言うと、配偶者の関係の候補となる人物名のペアを入力とすることで、その人物名同士が配偶者である確率が出力された。
https://hazyresearch.github.io/snorkel/pdfs/snorkel_demo.pdfより引用
識別モデルではこの確率値を正解ラベルとして使用し、訓練を行う。
Weak Supervisionより引用
つまり、訓練データから0~1の連続値を出力するモデルを構築することを意味する。なので、使用するモデルはロジスティック回帰・SVMs・LSTMなど自由。
なぜ、最終的な予測に識別モデルを使うのか。生成モデルの結果ではダメなのか。識別モデルを用いる理由については、
The discriminative model learns a feature representation of our LFs.
This makes it better able to generalize to unseen candidates.
snorkel/Snorkel-Workshop-FINAL.pdf at master · HazyResearch/snorkel · GitHub
と説明されている。
要は「汎化性能を向上させるための表現学習」として必要だということである。
コードを読んでいく
今回のチュートリアルは以下にある。
github.com
Training a LSTM Neural Network
train_cands
:訓練データ(候補集合)train_marginals
:正解ラベル(生成モデルから出力された確率値)
from snorkel.learning.disc_models.rnn import reRNN train_kwargs = { 'lr': 0.001, 'dim': 100, 'n_epochs': 10, 'dropout': 0.25, 'print_freq': 1, 'batch_size': 128, 'max_sentence_length': 100 } lstm = reRNN(seed=1701, n_threads=1) lstm.train(train_cands, train_marginals, X_dev=dev_cands, Y_dev=L_gold_dev, **train_kwargs)
今回は真の正解ラベル付きデータ(**_dev)があるので、精度も評価できる。
p, r, f1 = lstm.score(test_cands, L_gold_test) print("Prec: {0:.3f}, Recall: {1:.3f}, F1 Score: {2:.3f}".format(p, r, f1)) tp, fp, tn, fn = lstm.error_analysis(session, test_cands, L_gold_test)