TheanoでEmbedding
kerasでは関数が用意されているが、Theanoの場合は自分で定義する必要がある。
kerasについてはこっちでまとめた。
kento1109.hatenablog.com
コードを読んでいても理解は難しくないが、実装する時にもう少し挙動を理解したいと思ったのでメモする。
EmbeddingLayerの定義
下記でEmbeddingの重み行列を作成する。
作成時に必要な情報は、「語彙数・次元数」
(下記の例は語彙数=20、次元数=5で作成している。)
この定義は、モデル作成時にのみ記述する。
import numpy as np import theano import theano.tensor as T vocab = 20 embedded_dim = 5 np.set_printoptions(precision=3) np.random.seed(1) # embedding initialization emd = theano.shared(np.random.uniform(-1.0, 1.0, (vocab, embedded_dim)).astype(theano.config.floatX)) print emd.eval() [[-0.166 0.441 -1. -0.395 -0.706] [-0.815 -0.627 -0.309 -0.206 0.078] [-0.162 0.37 -0.591 0.756 -0.945] [ 0.341 -0.165 0.117 -0.719 -0.604] [ 0.601 0.937 -0.373 0.385 0.753] [ 0.789 -0.83 -0.922 -0.66 0.756] [-0.803 -0.158 0.916 0.066 0.384] [-0.369 0.373 0.669 -0.963 0.5 ] [ 0.978 0.496 -0.439 0.579 -0.794] [-0.104 0.817 -0.413 -0.424 -0.74 ] [-0.961 0.358 -0.577 -0.469 -0.017] [-0.893 0.148 -0.707 0.179 0.4 ] [-0.795 -0.172 0.389 -0.172 -0.9 ] [ 0.072 0.328 0.03 0.889 0.173] [ 0.807 -0.725 -0.721 0.615 -0.205] [-0.669 0.855 -0.304 0.502 0.452] [ 0.767 0.247 0.502 -0.302 -0.46 ] [ 0.792 -0.144 0.93 0.327 0.243] [-0.771 0.899 -0.1 0.157 -0.184] [-0.526 0.807 0.147 -0.994 0.234]]
入力系列の作成
下記のような入力系列を作成しておく。
入力系列は、バッチ単位(or文章単位)で異なる。
(番号は定義済みの単語のインデックスを表す。)
※入力系列単位で入力データの長さは揃える必要がある。
batch_size = 4 sent_length = 10 input_dim = (batch_size, sent_length) input = np.random.randint(vocab, size=input_dim) print input [[16 5 13 17 18 1 10 0 7 0] [19 17 14 13 11 6 13 15 9 2] [ 7 5 4 5 8 13 17 17 15 13] [ 8 14 13 16 10 13 3 2 14 14]]
入力系列をEmbedding
Embedding後は、入力系列+1次元の出力となる。
(今回は入力次元が行列なので、出力は3階テンソル)
下記を見ると、先頭の単語ID「16」は[ 0.767 0.247 0.502 -0.302 -0.46 ]に変換されていることが分かる。
output = emd[input] print output.eval() [[[ 0.767 0.247 0.502 -0.302 -0.46 ] [ 0.789 -0.83 -0.922 -0.66 0.756] [ 0.072 0.328 0.03 0.889 0.173] [ 0.792 -0.144 0.93 0.327 0.243] [-0.771 0.899 -0.1 0.157 -0.184] [-0.815 -0.627 -0.309 -0.206 0.078] [-0.961 0.358 -0.577 -0.469 -0.017] [-0.166 0.441 -1. -0.395 -0.706] [-0.369 0.373 0.669 -0.963 0.5 ] [-0.166 0.441 -1. -0.395 -0.706]] [[-0.526 0.807 0.147 -0.994 0.234] [ 0.792 -0.144 0.93 0.327 0.243] [ 0.807 -0.725 -0.721 0.615 -0.205] [ 0.072 0.328 0.03 0.889 0.173] [-0.893 0.148 -0.707 0.179 0.4 ] [-0.803 -0.158 0.916 0.066 0.384] [ 0.072 0.328 0.03 0.889 0.173] [-0.669 0.855 -0.304 0.502 0.452] [-0.104 0.817 -0.413 -0.424 -0.74 ] [-0.162 0.37 -0.591 0.756 -0.945]] [[-0.369 0.373 0.669 -0.963 0.5 ] [ 0.789 -0.83 -0.922 -0.66 0.756] [ 0.601 0.937 -0.373 0.385 0.753] [ 0.789 -0.83 -0.922 -0.66 0.756] [ 0.978 0.496 -0.439 0.579 -0.794] [ 0.072 0.328 0.03 0.889 0.173] [ 0.792 -0.144 0.93 0.327 0.243] [ 0.792 -0.144 0.93 0.327 0.243] [-0.669 0.855 -0.304 0.502 0.452] [ 0.072 0.328 0.03 0.889 0.173]] [[ 0.978 0.496 -0.439 0.579 -0.794] [ 0.807 -0.725 -0.721 0.615 -0.205] [ 0.072 0.328 0.03 0.889 0.173] [ 0.767 0.247 0.502 -0.302 -0.46 ] [-0.961 0.358 -0.577 -0.469 -0.017] [ 0.072 0.328 0.03 0.889 0.173] [ 0.341 -0.165 0.117 -0.719 -0.604] [-0.162 0.37 -0.591 0.756 -0.945] [ 0.807 -0.725 -0.721 0.615 -0.205] [ 0.807 -0.725 -0.721 0.615 -0.205]]]