Commit 1d74ac21 authored by Yuxin Wu's avatar Yuxin Wu

add FashionMnist

parent 994a150b
......@@ -8,7 +8,7 @@ Feature Requests:
2. Add a new feature. Please note that, you can implement a lot of features by extending tensorpack
(See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack).
It may not have to be added to tensorpack unless you have a good reason.
3. Note that we don't take "example requests".
3. Note that we don't implement papers at others' requests.
Usage Questions:
Usage questions are like "How do I do [this specific thing] in tensorpack?".
......
......@@ -12,17 +12,16 @@ from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow
__all__ = ['Mnist']
__all__ = ['Mnist', 'FashionMnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(url, work_directory):
    """Download ``url`` into ``work_directory`` unless the file already exists.

    Args:
        url (str): full URL of the file to fetch. The text after its last
            ``/`` is used as the local file name.
        work_directory (str): directory that is checked for the file and
            handed to ``download`` as the destination.

    Returns:
        str: ``work_directory`` joined with the file name taken from ``url``.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(work_directory, filename)
    # Only hit the network when the file is not already present locally.
    if not os.path.exists(filepath):
        logger.info("Downloading to {}...".format(filepath))
        download(url, work_directory)
    return filepath
......@@ -69,6 +68,9 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int.
"""
DIR_NAME = 'mnist_data'
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def __init__(self, train_or_test, shuffle=True, dir=None):
"""
Args:
......@@ -76,15 +78,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset
"""
if dir is None:
dir = get_dataset_path('mnist_data')
dir = get_dataset_path(self.DIR_NAME)
assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test
self.shuffle = shuffle
def get_images_and_labels(image_file, label_file):
    """Fetch (if missing) and decode one image file and its label file.

    Both file names are appended to ``self.SOURCE_URL`` to form the
    download URL, so a subclass can change the source by overriding
    that class attribute.
    """
    image_path = maybe_download(self.SOURCE_URL + image_file, dir)
    images = extract_images(image_path)
    label_path = maybe_download(self.SOURCE_URL + label_file, dir)
    labels = extract_labels(label_path)
    # There must be exactly one label per image.
    assert images.shape[0] == labels.shape[0]
    return images, labels
......@@ -111,6 +113,11 @@ class Mnist(RNGDataFlow):
yield [img, label]
class FashionMnist(Mnist):
    """
    Same interface as :class:`Mnist`; only the dataset directory name and
    the base download URL are overridden.
    """
    # Name passed to get_dataset_path() when no explicit `dir` is given.
    DIR_NAME = 'fashion_mnist_data'
    # Prefix prepended to each data file name to build its download URL.
    SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
if __name__ == '__main__':
ds = Mnist('train')
for (img, label) in ds.get_data():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment