Commit 6ef876c9 authored by Yuxin Wu

simplify mnist dataset code

parent 26f09ada
...@@ -45,9 +45,9 @@ def extract_images(filename): ...@@ -45,9 +45,9 @@ def extract_images(filename):
buf = bytestream.read(rows * cols * num_images) buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8) data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1) data = data.reshape(num_images, rows, cols, 1)
data = data.astype('float32') / 255.0
return data return data
def extract_labels(filename): def extract_labels(filename):
"""Extract the labels into a 1D uint8 numpy array [index].""" """Extract the labels into a 1D uint8 numpy array [index]."""
with gzip.open(filename) as bytestream: with gzip.open(filename) as bytestream:
...@@ -61,37 +61,6 @@ def extract_labels(filename): ...@@ -61,37 +61,6 @@ def extract_labels(filename):
labels = numpy.frombuffer(buf, dtype=numpy.uint8) labels = numpy.frombuffer(buf, dtype=numpy.uint8)
return labels return labels
class DataSet(object):
    """Container pairing flattened, rescaled images with their labels.

    The constructor takes images shaped
    [num examples, rows, columns, depth] (depth must be 1), flattens each
    image to a vector of length rows*columns and rescales the values by
    1/255 as float32. Labels are stored unchanged.
    """

    def __init__(self, images, labels, fake_data=False):
        """Construct a DataSet.

        Args:
            images: array of shape [num examples, rows, columns, 1].
            labels: array whose first dimension equals ``images.shape[0]``.
            fake_data: accepted for interface compatibility; this
                constructor does not read it.

        Raises:
            AssertionError: if the number of images and labels differ, or
                if the depth dimension of ``images`` is not 1.
        """
        assert images.shape[0] == labels.shape[0], (
            'images.shape: %s labels.shape: %s' % (images.shape,
                                                   labels.shape))
        self._num_examples = images.shape[0]

        # Convert shape from [num examples, rows, columns, depth]
        # to [num examples, rows*columns] (assuming depth == 1)
        assert images.shape[3] == 1, (
            'expected depth 1, images.shape: %s' % (images.shape,))
        images = images.reshape(images.shape[0],
                                images.shape[1] * images.shape[2])

        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)

        self._images = images
        self._labels = labels

    @property
    def images(self):
        """Flattened images, shape [num examples, rows*columns]."""
        return self._images

    @property
    def labels(self):
        """Labels exactly as passed to the constructor."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples (first dimension of the input images)."""
        return self._num_examples
class Mnist(RNGDataFlow): class Mnist(RNGDataFlow):
""" """
Return [image, label], Return [image, label],
...@@ -108,38 +77,33 @@ class Mnist(RNGDataFlow): ...@@ -108,38 +77,33 @@ class Mnist(RNGDataFlow):
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle self.shuffle = shuffle
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' def get_images_and_labels(image_file, label_file):
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' f = maybe_download(image_file, dir)
TEST_IMAGES = 't10k-images-idx3-ubyte.gz' images = extract_images(f)
TEST_LABELS = 't10k-labels-idx1-ubyte.gz' f = maybe_download(label_file, dir)
labels = extract_labels(f)
local_file = maybe_download(TRAIN_IMAGES, dir) assert images.shape[0] == labels.shape[0]
train_images = extract_images(local_file) return images, labels
local_file = maybe_download(TRAIN_LABELS, dir) if self.train_or_test == 'train':
train_labels = extract_labels(local_file) self.images, self.labels = get_images_and_labels(
'train-images-idx3-ubyte.gz',
local_file = maybe_download(TEST_IMAGES, dir) 'train-labels-idx1-ubyte.gz')
test_images = extract_images(local_file) else:
self.images, self.labels = get_images_and_labels(
local_file = maybe_download(TEST_LABELS, dir) 't10k-images-idx3-ubyte.gz',
test_labels = extract_labels(local_file) 't10k-labels-idx1-ubyte.gz')
self.train = DataSet(train_images, train_labels)
self.test = DataSet(test_images, test_labels)
def size(self): def size(self):
ds = self.train if self.train_or_test == 'train' else self.test return self.images.shape[0]
return ds.num_examples
def get_data(self): def get_data(self):
ds = self.train if self.train_or_test == 'train' else self.test idxs = list(range(self.size()))
idxs = list(range(ds.num_examples))
if self.shuffle: if self.shuffle:
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for k in idxs: for k in idxs:
img = ds.images[k].reshape((28, 28)) img = self.images[k].reshape((28, 28))
label = ds.labels[k] label = self.labels[k]
yield [img, label] yield [img, label]
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment