Pytorch:テキストのバッチ化(BucketIterator)
前回、torchtextに関する基本をまとめた。
今回、もう少し実用的なことをメモする。
BucketIterator
テキストを学習データとする場合、当然、文章の系列長は異なる。文章をバッチ化する場合、パディングして系列長を揃える必要がある。
BucketIteratorを用いると、
- 似た系列長をなるべくまとめて
- 辞書に基づき、単語をインデックス化して、
バッチ化してくれる。
早速、試してみる。
今回は下記のようなテキストを入力とする。(7文)
apple cake dog work pen note desk cat bed book room dance note pen apple body soccer room girl like love apple cat pen room body light book pen desk mother car soccer book battle soccer
まずは、TabularDataset
でデータを読み込む。
from torchtext import data, datasets TEXT = data.Field(sequential=True, use_vocab=True) pos = data.TabularDataset( path=base_path + "/word.txt",format='csv', fields=[('text', TEXT)]) # make vocaburary TEXT.build_vocab(pos)
読み込んだ後は、辞書を作成しておく。
これをしないと後でインデックス化できない。
(辞書がないと怒られる。)
このデータセットをバッチ化する。
train_iter = data.BucketIterator(dataset=pos, batch_size=3, device=-1, repeat=False)
注意点としては、
- CPUの場合、
device
に「-1」を指定 - 繰り返さない場合、
repeat
に「False」を指定
実際に見てみる。
for i, train in enumerate(train_iter): print train.text
Variable containing: 8 3 2 12 13 5 4 16 7 5 22 18 15 2 4 10 10 1 2 9 1 [torch.LongTensor of size 7x3] Variable containing: 5 4 3 17 11 7 19 6 6 20 1 1 3 1 1 8 1 1 [torch.LongTensor of size 6x3] Variable containing: 2 9 21 14 6 [torch.LongTensor of size 5x1]
文章がインデックス化されてバッチ分割されている。
(「1」はpaddingを表すインデックス)
しかし、似た系列長でバッチ化されているわけではない。。
調べた感じ、勝手に似た系列でバッチ化してくれるわけではない。。
以下のようにsort_key
で並び順を明示的に指定することで、その順番でバッチ化されるみたいだ。
train_iter = data.BucketIterator(dataset=pos, batch_size=3, device=-1, repeat=False, sort_key=lambda x: len(x.text))
Variable containing: 2 5 8 5 17 12 7 19 4 18 20 5 4 3 15 1 8 10 1 1 2 [torch.LongTensor of size 7x3] Variable containing: 3 4 2 7 11 9 6 6 21 1 1 14 1 1 6 [torch.LongTensor of size 5x3] Variable containing: 3 13 16 22 2 10 9 [torch.LongTensor of size 7x1]
なるほど勉強になった。
10/12 補足
勉強になったと言いつつも、Iterator
クラスの内部を深く理解していなかったので補足する。特にどこで単語インデックスのバッチオブジェクトを作っているかを押さえる。
まず、
train_iter = data.BucketIterator()
でIterator
インスタンスの生成。
※ここではインスタンスを生成しているだけ
重要なのが、
for train in train_iter: ...
の部分。
イテレータの部分で、Iterator
インスタンスの__iter__()
メソッドが呼ばれる。
では、その一部を見ていく。
def __iter__(self): while True: self.init_epoch() for idx, minibatch in enumerate(self.batches): ...
まずは、init_epoch()
の呼び出し。
この中身を確認する。
def init_epoch(self): """Set up the batch generator for a new epoch.""" if self._restored_from_state: self.random_shuffler.random_state = self._random_state_this_epoch else: self._random_state_this_epoch = self.random_shuffler.random_state self.create_batches() ...
初めに行っているのはcreate_batches()
の呼び出し。
名前からしてこのメソッドが重要な役割を担ってそう。
def create_batches(self): self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)
まず、data()
メソッド。
これは、自身のデータセットを返すだけ。
ただし、ソートやシャッフルが指定されている場合はその処理を施したデータセットを返す。
次に、batch()
メソッド。
def batch(data, batch_size, batch_size_fn=None): """Yield elements from data in chunks of batch_size.""" if batch_size_fn is None: def batch_size_fn(new, count, sofar): return count minibatch, size_so_far = [], 0 for ex in data: minibatch.append(ex) size_so_far = batch_size_fn(ex, len(minibatch), size_so_far) if size_so_far == batch_size: yield minibatch minibatch, size_so_far = [], 0 elif size_so_far > batch_size: yield minibatch[:-1] minibatch, size_so_far = minibatch[-1:], batch_size_fn(ex, 1, 0) if minibatch: yield minibatch
このメソッドはデータセットをミニバッチ化して、イテレーションの度にミニバッチを渡す役割を担う。
現段階で辞書や単語インデックスは登場していない。
さて、__iter__()
メソッドに戻る。
def __iter__(self): while True: self.init_epoch() for idx, minibatch in enumerate(self.batches): .. yield Batch(minibatch, self.dataset, self.device)
最後のyield
まで飛ばす。
結論を言うと、このBatch
クラスが単語のインデックス化を行うクラスだった。
このオブジェクトをみていく。
class Batch(object): def __init__(self, data=None, dataset=None, device=None): ... for (name, field) in dataset.fields.items(): if field is not None: batch = [getattr(x, name) for x in data] setattr(self, name, field.process(batch, device=device))
大事なのは最後のsetattr()
の部分。
setattr()
はオブジェクトに属性を追加を追加する組み込み関数。
詳細はここに分かりやすく書いてある。
setattr属性の追加 - Python学習講座
つまり、バッチオブジェクトのname
(field名)属性にfield.process()
の結果を追加する。
だいぶ目的に近づいてきた。(後二つで答えにたどり着く。)
def process(self, batch, device=None): padded = self.pad(batch) tensor = self.numericalize(padded, device=device) return tensor
さてインデックス化をしてそうなメソッド名がある。numericalize
である。
def numericalize(self, arr, device=None): if self.use_vocab: if self.sequential: arr = [[self.vocab.stoi[x] for x in ex] for ex in arr] else: arr = [self.vocab.stoi[x] for x in arr] var = torch.tensor(arr, dtype=self.dtype, device=device)
※長かったので一部のみ抜粋。
全コードは下記参照
github.com
ようやく目的のvocab.stoi
が見つかった。
ここまで追えればもう少しtorchetxt
を理解して使えそうだ。