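Machine Learning / NLP Study Notes

Notes and summaries of things I have learned

Loading and Plotting MNIST

When I looked into it, I found several ways to plot the images and was unsure which one to use.
This is a memo so that I don't waste time on it again.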

Download

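# six.moves exposes the same urllib.request interface on Python 2 and 3
from six.moves import urllib

origin = (
    'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'
)
# Save the archive into the current directory
urllib.request.urlretrieve(origin, 'mnist.pkl.gz')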
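
Running the download line again fetches the file every time. Below is a minimal sketch that skips the download when the file is already on disk. It is written for Python 3 with the standard-library `urllib.request`. The function name `download_if_missing` and its `fetch` parameter are made up for illustration and are not part of the original note.

```python
import os
from urllib.request import urlretrieve

# Same URL as in the download snippet
MNIST_URL = 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz'


def download_if_missing(url, path, fetch=urlretrieve):
    """Download url to path unless path already exists.

    Returns True if a download happened, False if the file was reused.
    `fetch` is a hypothetical hook so the function can be tried offline.
    """
    if os.path.exists(path):
        return False
    fetch(url, path)
    return True
```

Calling `download_if_missing(MNIST_URL, 'mnist.pkl.gz')` would then only hit the network on the first run.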

Loading

import gzip
import cPickle as pickle  # Python 2 module

import numpy as np
import matplotlib.pyplot as plt

# The archive holds three (images, labels) tuples: train, validation, test
with gzip.open('mnist.pkl.gz', 'rb') as f:
    train_set, valid_set, test_set = pickle.load(f)

train_set_x, train_set_y = train_set
valid_set_x, valid_set_y = valid_set
test_set_x, test_set_y = test_set
print "train_set_x:", train_set_x.shape
print "train_set_y:", train_set_y.shape
print "valid_set_x:", valid_set_x.shape
print "valid_set_y:", valid_set_y.shape
print "test_set_x:", test_set_x.shape
print "test_set_y:", test_set_y.shape
print "shape:", train_set_x[0].shape

# train_set_x: (50000, 784)
# train_set_y: (50000,)
# valid_set_x: (10000, 784)
# valid_set_y: (10000,)
# test_set_x: (10000, 784)
# test_set_y: (10000,)
# shape: (784,)
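
The loading code here targets Python 2 (`cPickle`, `print` statements). Under Python 3, a minimal sketch of the same load might look like the following. `load_mnist` is a name made up here, and the `encoding='latin1'` argument assumes the archive was pickled by Python 2, which is the usual situation for this file.

```python
import gzip
import pickle


def load_mnist(path='mnist.pkl.gz'):
    """Return (train_set, valid_set, test_set) from the gzipped pickle.

    Assumes Python 3. For a file pickled under Python 2,
    encoding='latin1' lets the byte strings inside it be decoded.
    NumPy must be installed to unpickle the real MNIST arrays.
    """
    with gzip.open(path, 'rb') as f:
        return pickle.load(f, encoding='latin1')
```

Usage would be `train_set, valid_set, test_set = load_mnist()`, after which each set unpacks into images and labels as before.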

Plotting

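# Draw the first 100 training images in a 10x10 grid
for i in range(100):
    plt.subplot(10, 10, i + 1)  # subplot indices start at 1
    # Each image is a flat 784-vector; reshape it to 28x28 pixels
    plt.imshow(train_set_x[i].reshape(28, 28), cmap='gray')
    plt.axis('off')
plt.subplots_adjust(wspace=0, hspace=0)  # no gaps between cells
plt.show()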

The result looks like this:
[Figure: the first 100 MNIST training images in a 10x10 grid]
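
Another of the possible plotting approaches is to tile the images into one large array with NumPy and draw it with a single `imshow` call. The sketch below covers only the tiling step. `make_grid` and its `side` parameter are names invented for this example, and it assumes each row of `images` is a flattened square image.

```python
import numpy as np


def make_grid(images, rows, cols, side=28):
    """Tile rows*cols flattened side x side images into one 2-D array."""
    tiles = np.asarray(images[:rows * cols]).reshape(rows, cols, side, side)
    # (rows, cols, side, side) -> (rows, side, cols, side) -> 2-D picture
    return tiles.transpose(0, 2, 1, 3).reshape(rows * side, cols * side)
```

With the arrays loaded earlier, `plt.imshow(make_grid(train_set_x, 10, 10), cmap='gray')` followed by `plt.axis('off')` and `plt.show()` should give the same 10x10 picture from one image instead of 100 subplots.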