Theanoのreshapeについて
reshapeを使ったときに嵌ったのでメモ。
たとえば、numpy
でreshapeを使う場合、
import numpy as np a = np.array([1, 2]) print a.reshape(2, 1) [[1] [2]]
こうすることで、ベクトルとして扱うことが出来る。
同じことをTheanoでやりたい場合、
たとえば、このx
,w
の内積を取りたい場合、
import numpy as np import theano import theano.tensor as T W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX)) x = T.vector('x') y = T.dot(x, W) f = theano.function([x], y) f([1, 2, 3, 4])
この場合、内積するベクトルの要素数が異なるので、エラーとなる。
このようにして揃えてあげたが、
W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX)) x = T.vector('x') y = T.dot(x.reshape(x.shape[0], 1), W)
なぜかreshapeされなかった。
色々調べていたら、()がもう一つ必要であることが分かった。
W = theano.shared(np.array([[1, 1.5]], dtype=theano.config.floatX)) x = T.vector('x') y = T.dot(x.reshape((x.shape[0], 1)), W) [[ 1. 1.5] [ 2. 3. ] [ 3. 4.5] [ 4. 6. ]]
こんなことに悩まないためにメモしておく。。