Commit 1d74ac21 authored by Yuxin Wu

add FashionMnist

parent 994a150b
...@@ -8,7 +8,7 @@ Feature Requests: ...@@ -8,7 +8,7 @@ Feature Requests:
2. Add a new feature. Please note that, you can implement a lot of features by extending tensorpack 2. Add a new feature. Please note that, you can implement a lot of features by extending tensorpack
(See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack). (See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack).
It may not have to be added to tensorpack unless you have a good reason. It may not have to be added to tensorpack unless you have a good reason.
3. Note that we don't take "example requests". 3. Note that we don't implement papers at others' requests.
Usage Questions: Usage Questions:
Usage questions are like "How do I do [this specific thing] in tensorpack?". Usage questions are like "How do I do [this specific thing] in tensorpack?".
......
...@@ -12,17 +12,16 @@ from ...utils import logger ...@@ -12,17 +12,16 @@ from ...utils import logger
from ...utils.fs import download, get_dataset_path from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['Mnist'] __all__ = ['Mnist', 'FashionMnist']
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(url, work_directory):
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here.""" """Download the data from Yann's website, unless it's already here."""
filename = url.split('/')[-1]
filepath = os.path.join(work_directory, filename) filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath): if not os.path.exists(filepath):
logger.info("Downloading mnist data to {}...".format(filepath)) logger.info("Downloading to {}...".format(filepath))
download(SOURCE_URL + filename, work_directory) download(url, work_directory)
return filepath return filepath
...@@ -69,6 +68,9 @@ class Mnist(RNGDataFlow): ...@@ -69,6 +68,9 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int. image is 28x28 in the range [0,1], label is an int.
""" """
DIR_NAME = 'mnist_data'
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
""" """
Args: Args:
...@@ -76,15 +78,15 @@ class Mnist(RNGDataFlow): ...@@ -76,15 +78,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset shuffle (bool): shuffle the dataset
""" """
if dir is None: if dir is None:
dir = get_dataset_path('mnist_data') dir = get_dataset_path(self.DIR_NAME)
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle self.shuffle = shuffle
def get_images_and_labels(image_file, label_file): def get_images_and_labels(image_file, label_file):
f = maybe_download(image_file, dir) f = maybe_download(self.SOURCE_URL + image_file, dir)
images = extract_images(f) images = extract_images(f)
f = maybe_download(label_file, dir) f = maybe_download(self.SOURCE_URL + label_file, dir)
labels = extract_labels(f) labels = extract_labels(f)
assert images.shape[0] == labels.shape[0] assert images.shape[0] == labels.shape[0]
return images, labels return images, labels
...@@ -111,6 +113,11 @@ class Mnist(RNGDataFlow): ...@@ -111,6 +113,11 @@ class Mnist(RNGDataFlow):
yield [img, label] yield [img, label]
class FashionMnist(Mnist):
DIR_NAME = 'fashion_mnist_data'
SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
if __name__ == '__main__': if __name__ == '__main__':
ds = Mnist('train') ds = Mnist('train')
for (img, label) in ds.get_data(): for (img, label) in ds.get_data():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment