機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Theanoのreshapeについて

reshapeを使ったときに嵌ったのでメモ。

たとえば、numpyreshapeを使う場合、

import numpy as np
a = np.array([1, 2])
print a.reshape(2, 1)
[[1]
 [2]]

こうすることで、ベクトルとして扱うことが出来る。

同じことをTheanoでやりたい場合、
たとえば、このx,w内積を取りたい場合、

import numpy as np
import theano
import theano.tensor as T

W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX))
x = T.vector('x')
y = T.dot(x, W)
f = theano.function([x], y)
f([1, 2, 3, 4])

この場合、内積するベクトルの要素数が異なるので、エラーとなる。
このようにして揃えてあげたが、

W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX))
x = T.vector('x')
y = T.dot(x.reshape(x.shape[0], 1), W)

なぜかreshapeされなかった。

色々調べていたら、()がもう一つ必要であることが分かった。

W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX))
x = T.vector('x')
y = T.dot(x.reshape((x.shape[0], 1)), W)

[[ 1.   1.5]
 [ 2.   3. ]
 [ 3.   4.5]
 [ 4.   6. ]]

こんなことに悩まないためにメモしておく。。