機械学習・自然言語処理の勉強メモ

学んだことのメモやまとめ

PyTorch入門④:utilsを使う。(torch.utils.data)

torch.utils.data



データセット読み込み関連ユーティリティ。
DataLoaderは、データのロード・前処理をするためのモジュール。
必ずしもこれを使わなければいけないことは無いが、前処理を楽にしてくれる。

等をオプション1つでやってくれるので便利。
実際にどのように使うか見てみる。
(やってることは前回と同じなのでバッチ処理の部分のみ見ていく。)

尚、全体のNote bookはここにまとめた。
github.com

data.TensorDataset
train_ = torch.utils.data.TensorDataset(torch.from_numpy(trX).float(), torch.from_numpy(trY.astype(np.int64)))

dataとtargetをデータセットとしてまとめる。

data.DataLoader
train_iter = torch.utils.data.DataLoader(train_, batch_size=64, shuffle=True)

batch_sizeにサイズを指定、shuffle=Trueを指定するだけ。
とても簡単。

以下でバッチ化されたサンプルを確認できる。

train_ = next(iter(train_iter))
print(train_)

その他のオプションはドキュメント参照
torch.utils.data — PyTorch master documentation

main

訓練は以下のようにloaderをイテレーションする。

torch.manual_seed(1)
for epoch in range(N_EPOCHS):
    loss = 0
    for i, train_data in enumerate(train_iter):
        inputs, labels = train_data
        loss += train(model, loss_func, optimizer, inputs, labels)

とても簡単で便利なので活用したい。