機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

TheanoでEmbedding

kerasでは関数が用意されているが、Theanoの場合は自分で定義する必要がある。

kerasについてはこっちでまとめた。
kento1109.hatenablog.com

コードを読んでいても理解は難しくないが、実装する時にもう少し挙動を理解したいと思ったのでメモする。



EmbeddingLayerの定義

下記でEmbeddingの重み行列を作成する。
作成時に必要な情報は、「語彙数・次元数
(下記の例は語彙数=20、次元数=5で作成している。)
この定義は、モデル作成時にのみ記述する。

import numpy as np
import theano
import theano.tensor as T

vocab = 20
embedded_dim = 5

np.set_printoptions(precision=3)
np.random.seed(1)
# embedding initialization
emd = theano.shared(np.random.uniform(-1.0, 1.0, (vocab, embedded_dim)).astype(theano.config.floatX))
print emd.eval()
[[-0.166  0.441 -1.    -0.395 -0.706]
 [-0.815 -0.627 -0.309 -0.206  0.078]
 [-0.162  0.37  -0.591  0.756 -0.945]
 [ 0.341 -0.165  0.117 -0.719 -0.604]
 [ 0.601  0.937 -0.373  0.385  0.753]
 [ 0.789 -0.83  -0.922 -0.66   0.756]
 [-0.803 -0.158  0.916  0.066  0.384]
 [-0.369  0.373  0.669 -0.963  0.5  ]
 [ 0.978  0.496 -0.439  0.579 -0.794]
 [-0.104  0.817 -0.413 -0.424 -0.74 ]
 [-0.961  0.358 -0.577 -0.469 -0.017]
 [-0.893  0.148 -0.707  0.179  0.4  ]
 [-0.795 -0.172  0.389 -0.172 -0.9  ]
 [ 0.072  0.328  0.03   0.889  0.173]
 [ 0.807 -0.725 -0.721  0.615 -0.205]
 [-0.669  0.855 -0.304  0.502  0.452]
 [ 0.767  0.247  0.502 -0.302 -0.46 ]
 [ 0.792 -0.144  0.93   0.327  0.243]
 [-0.771  0.899 -0.1    0.157 -0.184]
 [-0.526  0.807  0.147 -0.994  0.234]]
入力系列の作成

下記のような入力系列を作成しておく。
入力系列は、バッチ単位(or文章単位)で異なる。
(番号は定義済みの単語のインデックスを表す。)
※入力系列単位で入力データの長さは揃える必要がある。

batch_size = 4
sent_length = 10
input_dim = (batch_size, sent_length)
input = np.random.randint(vocab, size=input_dim)
print input
[[16  5 13 17 18  1 10  0  7  0]
 [19 17 14 13 11  6 13 15  9  2]
 [ 7  5  4  5  8 13 17 17 15 13]
 [ 8 14 13 16 10 13  3  2 14 14]]
入力系列をEmbedding

Embedding後は、入力系列+1次元の出力となる。
(今回は入力次元が行列なので、出力は3階テンソル
下記を見ると、先頭の単語ID「16」は[ 0.767 0.247 0.502 -0.302 -0.46 ]に変換されていることが分かる。

output = emd[input]

print output.eval()
[[[ 0.767  0.247  0.502 -0.302 -0.46 ]
  [ 0.789 -0.83  -0.922 -0.66   0.756]
  [ 0.072  0.328  0.03   0.889  0.173]
  [ 0.792 -0.144  0.93   0.327  0.243]
  [-0.771  0.899 -0.1    0.157 -0.184]
  [-0.815 -0.627 -0.309 -0.206  0.078]
  [-0.961  0.358 -0.577 -0.469 -0.017]
  [-0.166  0.441 -1.    -0.395 -0.706]
  [-0.369  0.373  0.669 -0.963  0.5  ]
  [-0.166  0.441 -1.    -0.395 -0.706]]

 [[-0.526  0.807  0.147 -0.994  0.234]
  [ 0.792 -0.144  0.93   0.327  0.243]
  [ 0.807 -0.725 -0.721  0.615 -0.205]
  [ 0.072  0.328  0.03   0.889  0.173]
  [-0.893  0.148 -0.707  0.179  0.4  ]
  [-0.803 -0.158  0.916  0.066  0.384]
  [ 0.072  0.328  0.03   0.889  0.173]
  [-0.669  0.855 -0.304  0.502  0.452]
  [-0.104  0.817 -0.413 -0.424 -0.74 ]
  [-0.162  0.37  -0.591  0.756 -0.945]]

 [[-0.369  0.373  0.669 -0.963  0.5  ]
  [ 0.789 -0.83  -0.922 -0.66   0.756]
  [ 0.601  0.937 -0.373  0.385  0.753]
  [ 0.789 -0.83  -0.922 -0.66   0.756]
  [ 0.978  0.496 -0.439  0.579 -0.794]
  [ 0.072  0.328  0.03   0.889  0.173]
  [ 0.792 -0.144  0.93   0.327  0.243]
  [ 0.792 -0.144  0.93   0.327  0.243]
  [-0.669  0.855 -0.304  0.502  0.452]
  [ 0.072  0.328  0.03   0.889  0.173]]

 [[ 0.978  0.496 -0.439  0.579 -0.794]
  [ 0.807 -0.725 -0.721  0.615 -0.205]
  [ 0.072  0.328  0.03   0.889  0.173]
  [ 0.767  0.247  0.502 -0.302 -0.46 ]
  [-0.961  0.358 -0.577 -0.469 -0.017]
  [ 0.072  0.328  0.03   0.889  0.173]
  [ 0.341 -0.165  0.117 -0.719 -0.604]
  [-0.162  0.37  -0.591  0.756 -0.945]
  [ 0.807 -0.725 -0.721  0.615 -0.205]
  [ 0.807 -0.725 -0.721  0.615 -0.205]]]