機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

Pytorch:テキストのバッチ化(BucketIterator)

前回、torchtextに関する基本をまとめた。

kento1109.hatenablog.com

今回、もう少し実用的なことをメモする。

BucketIterator



テキストを学習データとする場合、当然、文章の系列長は異なる。文章をバッチ化する場合、パディングして系列長を揃える必要がある。
BucketIteratorを用いると、

  • 似た系列長をなるべくまとめて
  • 辞書に基づき、単語をインデックス化して、

バッチ化してくれる。

早速、試してみる。
今回は下記のようなテキストを入力とする。(7文)

apple cake dog work pen note desk 
cat bed book room dance note pen
apple body soccer
room girl like love apple cat
pen room body light book
pen desk mother car soccer
book battle soccer

まずは、TabularDatasetでデータを読み込む。

from torchtext import data, datasets

TEXT = data.Field(sequential=True, use_vocab=True)
pos = data.TabularDataset(
        path=base_path + "/word.txt",format='csv',
        fields=[('text', TEXT)])

# make vocaburary
TEXT.build_vocab(pos)

読み込んだ後は、辞書を作成しておく。
これをしないと後でインデックス化できない。
(辞書がないと怒られる。)

このデータセットをバッチ化する。

train_iter = data.BucketIterator(dataset=pos, batch_size=3, device=-1, 
                                 repeat=False)

注意点としては、

  • CPUの場合、deviceに「-1」を指定
  • 繰り返さない場合、repeatに「False」を指定

実際に見てみる。

for i, train in enumerate(train_iter):
    print train.text
Variable containing:
    8     3     2
   12    13     5
    4    16     7
    5    22    18
   15     2     4
   10    10     1
    2     9     1
[torch.LongTensor of size 7x3]

Variable containing:
  5   4   3
 17  11   7
 19   6   6
 20   1   1
  3   1   1
  8   1   1
[torch.LongTensor of size 6x3]

Variable containing:
  2
  9
 21
 14
  6
[torch.LongTensor of size 5x1]

文章がインデックス化されてバッチ分割されている。
(「1」はpaddingを表すインデックス)
しかし、似た系列長でバッチ化されているわけではない。。

調べた感じ、勝手に似た系列でバッチ化してくれるわけではない。。
以下のようにsort_keyで並び順を明示的に指定することで、その順番でバッチ化されるみたいだ。

train_iter = data.BucketIterator(dataset=pos, batch_size=3, device=-1, 
                                 repeat=False, sort_key=lambda x: len(x.text))
Variable containing:
    2     5     8
    5    17    12
    7    19     4
   18    20     5
    4     3    15
    1     8    10
    1     1     2
[torch.LongTensor of size 7x3]

Variable containing:
  3   4   2
  7  11   9
  6   6  21
  1   1  14
  1   1   6
[torch.LongTensor of size 5x3]

Variable containing:
    3
   13
   16
   22
    2
   10
    9
[torch.LongTensor of size 7x1]

なるほど勉強になった。

10/12 補足

勉強になったと言いつつも、Iteratorクラスの内部を深く理解していなかったので補足する。特にどこで単語インデックスのバッチオブジェクトを作っているかを押さえる。

まず、

train_iter = data.BucketIterator()

Iteratorインスタンスの生成。
※ここではインスタンスを生成しているだけ

重要なのが、

for train in train_iter:
    ...

の部分。
イテレータの部分で、Iteratorインスタンス__iter__()メソッドが呼ばれる。
では、その一部を見ていく。

def __iter__(self):
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            ...

まずは、init_epoch()の呼び出し。
この中身を確認する。

def init_epoch(self):
    """Set up the batch generator for a new epoch."""

    if self._restored_from_state:
        self.random_shuffler.random_state = self._random_state_this_epoch
    else:
        self._random_state_this_epoch = self.random_shuffler.random_state

    self.create_batches()
    ...

初めに行っているのはcreate_batches()の呼び出し。
名前からしてこのメソッドが重要な役割を担ってそう。

def create_batches(self):
    self.batches = batch(self.data(), self.batch_size, self.batch_size_fn)

まず、data()メソッド。
これは、自身のデータセットを返すだけ。
ただし、ソートやシャッフルが指定されている場合はその処理を施したデータセットを返す。
次に、batch()メソッド。

def batch(data, batch_size, batch_size_fn=None):
    """Yield elements from data in chunks of batch_size."""
    if batch_size_fn is None:
        def batch_size_fn(new, count, sofar):
            return count
    minibatch, size_so_far = [], 0
    for ex in data:
        minibatch.append(ex)
        size_so_far = batch_size_fn(ex, len(minibatch), size_so_far)
        if size_so_far == batch_size:
            yield minibatch
            minibatch, size_so_far = [], 0
        elif size_so_far > batch_size:
            yield minibatch[:-1]
            minibatch, size_so_far = minibatch[-1:], batch_size_fn(ex, 1, 0)
    if minibatch:
        yield minibatch

このメソッドはデータセットをミニバッチ化して、イテレーションの度にミニバッチを渡す役割を担う。

現段階で辞書や単語インデックスは登場していない。

さて、__iter__()メソッドに戻る。

def __iter__(self):
    while True:
        self.init_epoch()
        for idx, minibatch in enumerate(self.batches):
            ..
            yield Batch(minibatch, self.dataset, self.device)

最後のyieldまで飛ばす。
結論を言うと、このBatchクラスが単語のインデックス化を行うクラスだった。
このオブジェクトをみていく。

class Batch(object):

    def __init__(self, data=None, dataset=None, device=None):
        ...
            for (name, field) in dataset.fields.items():
                if field is not None:
                    batch = [getattr(x, name) for x in data]
                    setattr(self, name, field.process(batch, device=device))

大事なのは最後のsetattr()の部分。
setattr()はオブジェクトに属性を追加を追加する組み込み関数。
詳細はここに分かりやすく書いてある。
setattr属性の追加 - Python学習講座

つまり、バッチオブジェクトのname(field名)属性にfield.process()の結果を追加する。
だいぶ目的に近づいてきた。(後二つで答えにたどり着く。)

def process(self, batch, device=None):
    padded = self.pad(batch)
    tensor = self.numericalize(padded, device=device)
    return tensor

さてインデックス化をしてそうなメソッド名がある。numericalizeである。

def numericalize(self, arr, device=None):
    if self.use_vocab:
        if self.sequential:
            arr = [[self.vocab.stoi[x] for x in ex] for ex in arr]
        else:
            arr = [self.vocab.stoi[x] for x in arr]
    var = torch.tensor(arr, dtype=self.dtype, device=device)

※長かったので一部のみ抜粋。
全コードは下記参照
github.com

ようやく目的のvocab.stoiが見つかった。

ここまで追えればもう少しtorchetxtを理解して使えそうだ。