機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

(論文)Cloze-driven Pretraining of Self-attention Networks

はじめに



2018年のNLPの主役は「BERT」で間違いないでしょう。
元の論文はGoogleから発表されており、Googleすごいってなりました。
黙っていないのがPytorchを開発した「Facebook」です。

ってことで、彼らの手法でNERのタスクにおいて僅かですがBERTを抜いてSOTAを達成しました。

今日はその論文を紹介したいと思います。
Cloze-driven Pretraining of Self-attention Networks

では、さっそく内容に入っていきます。

ELMoとGPTの流れを汲んだモデルみたいです。
(その点ではBERTと同じですね。)

cloze-style training

モデルの特徴は「bi-directional transformer」ということです。
「cloze-style」(穴埋め形式)の訓練手法という点が特徴的です。
前方(left-to-right)と後方(right-to-left)の文脈情報を与えて、中心の単語を予測します。
f:id:kento1109:20190508160307p:plain

BERTとの違い

気になるのはやはり「BERT」との違いです。
著者も意識していると思いますし、その違いを本文でも以下のように述べています。

我々のモデルは、系列中の各tokenを予測するbi-directional transformer型の言語モデルであり、
BERTもinput全体を利用することで、bi-directionalを実現している。
しかし、実現にはMasked LMやNext Sentence Predictionといった特殊なレジームを必要としている。
その一方で我々は周辺のtokenから各tokenを予測するモデルであり、1つの損失関数で事足りる。

要は学習時に特殊なことしなくてもよいよ、ってことですかね。

Block structure

各Blockの構造は「Transfomer」みたいなものです。
各Blockは2つのsub-blockから構成されます。
1つはH=16の「multi-head self-attention」です。

f:id:kento1109:20190427140223p:plain:h300
https://arxiv.org/pdf/1706.03762.pdf
※予測するtoken以降はマスクしておきます。

もう1つはfeed-forward networkです。

ここまでの入力を整理します。
まず、入力は文の各tokenを分散表現で表した行列です。それにposition embeddingsにより位置情報を足します。

f:id:kento1109:20190427121810p:plain:w300
https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
更にこの論文では、tokenの文字情報をCNNでエンコードした情報も加えています。

厳密性には欠けますが、入力情報は以下のような行列になるはずです。
f:id:kento1109:20190508170843p:plain:w300

これをforwardとbackwardでそれぞれ共通して利用します。
例えば、forwardでは「I, like」から「playing」単語を予測します。
backwardでは「soccer, ...」から「playing」単語を予測します。
ただ、文中のtokenを予測するモデルなのに、そのtoken以降の情報を含めてはいけません。
そこで、forwardとbackwardでそれぞれマスクをかけます。

Combination of representations

forwardとbackwardのtransformerモデルを組み合わせます。

  • base modelではsum
  • larger modelではconcat

しています。
forwardとbackwardを組み合わせることで、ターゲットtokenの周囲のtoken情報全体を利用することになります。
訓練は各tokenごとに行っているそうです。

Results

このモデルの上にタスク固有のモデルを繋げるそうです。
NERのタスクの場合、biLSTM-CRFを適用させているみたいです。
CoNLL 2003の結果です。
f:id:kento1109:20190509125437p:plain

最後に

最近はRNN系ではなく、transformerをベースとしたモデルが主流のようです。
ただ、NERのタスクで言うとRNN系でも十分に精度は高いと思います。
例えば、以下の論文でのDevセットでのF1値は94.74と今回のSOTAとも2ポイント程度しか違いません。
End-to-end Sequence Labeling via Bi-directional LSTM-CNNs-CRF
もちろん、この2ポイントの違い(RNN系とtransformer系のエラー分析)は気になりますが…