機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

(論文)Transformer

久しぶりにブログを更新する。

今日は「Attention Is All You Need」に関する復習。

もはや2年前の論文で、日本語でも丁寧な解説記事がたくさんある。

deeplearning.hatenablog.com

とっても今更感があるが、自分自身の理解の定着のためにまとめようと思う。
とりあえず、従来のRNN型のSeq2Seqモデルと何が違うのか、その辺を押さえたい。
後、理解を深めるために適宜コードも入れればと思う。
※コードに関しては以下サイトより引用させて頂いた。
towardsdatascience.com
nlp.seas.harvard.edu

Transformerモデル



論文中の図を引用する。
f:id:kento1109:20190427111935p:plain:w350

Decoderの入力となるOutputはターゲットは1tokenずつ右にシフトさせておく。
例えば、ターゲット文が「<BOS> 私 は サッカー が 好き <EOS>」の場合、
「<BOS> 私 は サッカー が 好き」としておく。
また、時刻tのtokenから時刻t+1のtokenを予測するためターゲットを左にシフトさせる。
(ターゲットは「私 は サッカー が 好き <EOS>」としておく。)

モデル自体は従来のSeq2Seqモデルと同じ構造。
新しく出現するキーワードは、

  • Positional Encoding
  • Multi-Head Attention
  • Position-wise Feed-Forward Networks

の3つ。

Positional Encoding



入力の位置情報の付与。
センテンスの文脈情報を理解するためには、各単語の順序が重要となる。
このモデルはRNNのように順序を考慮しないので、位置情報を明示的に付与する必要がある。
Positional Encodingは以下の式に基づいて行う。


PE_{(pos,2i)}=sin(pos/10000^{2i/d_{model}}) \\
PE_{(pos,2i+1)}=cos(pos/10000^{2i/d_{model}})

posはセンテンス中の出現位置 0,1,...,Tiは次元数(単語の分散表現の次元数)。

例えば、文の先頭位置に出現する単語のPositional Encodingを考える。
論文の単語の分散表現の次元数は512なので、以下のように先頭位置を表すベクトルは得られる。


PE_{(pos,2i)}=sin(0/10000^{2i/50})

このようして得られた行列をEmbedding後の文の行列に加算することで、各単語と位置(文脈)情報を考慮した情報が得られる。

f:id:kento1109:20190427121810p:plain:w300
https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec

Multi-Head Attention



Multi-Head Attention層は以下のようなイメージ。
f:id:kento1109:20190427140223p:plain:w300

入力を複数のHeadで分割する。

入力は{\rm X} = Q = K = V \in \mathbb{R}^{T\times d_{model}}
各パラメータは W_i^Q\in \mathbb{R}^{d_{model}\times d_k}W_i^K\in \mathbb{R}^{d_{model}\times d_k}W_i^V \in \mathbb{R}^{d_{model}\times d_k}

実験ではd_{model}=512h=8なので,d_k=64としている。
入力を各重みでの線形変換したもの

{\rm Attention}(QW_i^Q , KW_i^K , VW_i^V)
をScaled dot-product attentionの入力としている。

Scaled dot-product attention



この論文でのattentionモデルであり、タイトルにもあるようにここがモデルの肝の部分。
f:id:kento1109:20190427140247p:plain:w200
入力のQKVは先ほど計算したもの。
attentionの計算式は以下の通り。
\begin{eqnarray}{\rm Attention}(Q , K , V)={\rm softmax}(\frac{QK^T}{\sqrt d_k})V\end{eqnarray}

まず、QK^Tの部分。
Q ,K \in\mathbb{R}^{T\times d_k}だったので、QK^T \in \mathbb{R}^{T\times T}となる。
これに\sqrt d_kでスケーリングを行った後、softmaxを適用する。
これにV \in\mathbb{R}^{T\times d_k}を掛け合わせるので、最終的には、
{\rm Attention}(Q , K , V)={\rm softmax}(\frac{QK^T}{\sqrt d_k})V\in \mathbb{R}^{T\times h_d}となる。

これを各head毎に行ったものをConcatして重みW^O \in\mathbb{R}^{hd_v \times d_{model}}で線形変換を行う。

{\rm MultiHead}(Q , K , V)={\rm Concat(head_1, ..., head_h)}W^O
{\rm MultiHead}(Q , K , V)\in\mathbb{R}^{T\times d_{model}}となり、単語数×単語の分散表現の次元数に戻される。

結局のところ、この式変形が何を行っていたかをもう少し深堀してみる。

Query, Key, Value

まず、QK^Tの部分。

f:id:kento1109:20190427151617p:plain

スケーリングを無視して考えると、オレンジのスカラー値はQの赤のベクトルとKの青のベクトルの内積からなることが分かる。
f:id:kento1109:20190427160203p:plain
以下、同様となる。
f:id:kento1109:20190427160342p:plain

要するに、各queryと各keyの類似度を計算していることに相当する。
この行列の0行目のベクトルはquery0と各keyの類似度を表現したベクトルを意味する。
これをsoftmaxで正規化しており、各行の総和が1となるような重みベクトル(attention_weight)を作り出す。

最後にこの行列とV内積を取る。
f:id:kento1109:20190427163122p:plain
各queryのattention_weightとvalue内積を取ったものをattentionの結果として返す。
f:id:kento1109:20190427164937p:plain
なので、最終的な単語ベクトル(オレンジ)は、valueの各ベクトルにattention_weightを掛け合わせた値となることが分かる。
f:id:kento1109:20190427165130p:plain

Encoderでは、query, key, valueは全て同じ入力から渡される。(Self-Attention)
一方、Decoderでは、1層目のMulti-Head Attentionは全て出力から渡されるが、
2層目のQueryは出力から、key, valueはEncoderから渡される。(Source-Target-Attention)
Decoderの2層目のMulti-Head Attentionでは、出力の各queryとEncoderの各keyの類似度を取り、attention_weightを計算する。
それに出力のvalueを掛け合わせることで、Encoderのattention_weightを考慮した計算が可能となる。

もう少し具体例を挙げて考える。
Source文が「I like playing soccer」、Target文が「私/は/サッカー/が/好き/です」とする。
EncoderではSource文でのAttentionを計算する。

次に、Decoderでは、Source文とTarget文でAttentionを計算する。
query, key, valueはそれぞれ以下の通り。
f:id:kento1109:20190509102126p:plain
※Source文からkey, valueが、Target文からQueryが作られる。
attention_weight(QK^T)は以下のように計算される。
f:id:kento1109:20190509102146p:plain

attention_weightの計算を少しだけやってみる。
例えば、attention_weight(0,0)の値は次のように計算できる。
f:id:kento1109:20190509102528p:plain:w300
以下、同様である。
f:id:kento1109:20190509102700p:plain:w300
先頭のattention_weightは、Target文の「私」とSource文の各単語の内積を計算した結果となる。

f:id:kento1109:20190509104005p:plain

このattention_weightとvalueの行列計算を行う。
f:id:kento1109:20190509120816p:plain
attention_weightがこのような値の場合、水色のスカラー値はvalueの「soccer」の値に強く影響される一方、その他の値は0.1でスケーリングされるため、ほとんど影響をうけなくなる。
学習がうまくいくと、以降のネットワーク計算において、最も確率の高い単語として「サッカー」が出力されるようになる。

Mask



ターゲット文が「<BOS> 私 は サッカー が 好き <EOS>」の場合、
「<BOS> 私 は サッカー が 好き」としていた。
これは、時刻tのtokenから時刻t+1のtokenを予測するためであり、
「<BOS> 私 は」の情報から次のtokenが「サッカー」であることを予測していた。
つまり「<BOS> 私 は」の時点で「サッカー」という未来の情報は知るはずがない。
学習時は文として与えられるが、推論時は未来の情報は分からない。
このため、未来の情報にMaskをすることでリークを防ぐ。

Position-wise Feed-Forward Networks

attentionが理解できれば、この層は難しくない。
{\rm FFN}(x)={\rm ReLU}(xW_1+b_1)W_2+b_2

residual connection

ResNetで用いられている構造。

f:id:kento1109:20190427174124p:plain
https://deepage.net/deep_learning/2016/11/30/resnet.html

前の層の入力を足し合わせる。
これは「Positional Encoding & Multi-Head Attention」と「Multi-Head Attention & Position-wise Feed-Forward Networks」で利用されている。
実装例は以下の通り。

class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

これをEncoder内部で呼ぶ。

class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

Positional Encodingの出力がxとなる。

これをEncoder,Decoderで6層ずつ積んでいる。

こんな感じでTransformerモデルが出来る。
今更だが、attentionとか良い勉強になった。