PyTorch入門④:utilsを使う。(torch.utils.data)
torch.utils.data
データセット読み込み関連ユーティリティ。
DataLoader
は、データのロード・前処理をするためのモジュール。必ずしもこれを使わなければいけないことは無いが、前処理を楽にしてくれる。
等をオプション1つでやってくれるので便利。
実際にどのように使うか見てみる。
(やってることは前回と同じなのでバッチ処理の部分のみ見ていく。)
尚、全体のNote bookはここにまとめた。
github.com
data.TensorDataset
train_ = torch.utils.data.TensorDataset(torch.from_numpy(trX).float(), torch.from_numpy(trY.astype(np.int64)))
dataとtargetをデータセットとしてまとめる。
data.DataLoader
train_iter = torch.utils.data.DataLoader(train_, batch_size=64, shuffle=True)
batch_size
にサイズを指定、shuffle=True
を指定するだけ。
とても簡単。
以下でバッチ化されたサンプルを確認できる。
train_ = next(iter(train_iter)) print(train_)
その他のオプションはドキュメント参照
torch.utils.data — PyTorch master documentation
main
訓練は以下のようにloaderをイテレーションする。
torch.manual_seed(1) for epoch in range(N_EPOCHS): loss = 0 for i, train_data in enumerate(train_iter): inputs, labels = train_data loss += train(model, loss_func, optimizer, inputs, labels)
とても簡単で便利なので活用したい。