Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1d74ac21
Commit
1d74ac21
authored
Oct 15, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add FashionMnist
parent
994a150b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
10 deletions
+17
-10
.github/ISSUE_TEMPLATE.md
.github/ISSUE_TEMPLATE.md
+1
-1
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+16
-9
No files found.
.github/ISSUE_TEMPLATE.md
View file @
1d74ac21
...
...
@@ -8,7 +8,7 @@ Feature Requests:
2.
Add a new feature. Please note that, you can implement a lot of features by extending tensorpack
(See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack).
It may not have to be added to tensorpack unless you have a good reason.
3.
Note that we don't
take "example requests"
.
3.
Note that we don't
implement papers at other's requests
.
Usage Questions:
Usage questions are like "How do I do [this specific thing] in tensorpack?".
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
1d74ac21
...
...
@@ -12,17 +12,16 @@ from ...utils import logger
from
...utils.fs
import
download
,
get_dataset_path
from
..base
import
RNGDataFlow
__all__
=
[
'Mnist'
]
__all__
=
[
'Mnist'
,
'FashionMnist'
]
SOURCE_URL
=
'http://yann.lecun.com/exdb/mnist/'
def
maybe_download
(
filename
,
work_directory
):
def
maybe_download
(
url
,
work_directory
):
"""Download the data from Yann's website, unless it's already here."""
filename
=
url
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
work_directory
,
filename
)
if
not
os
.
path
.
exists
(
filepath
):
logger
.
info
(
"Downloading
mnist data
to {}..."
.
format
(
filepath
))
download
(
SOURCE_URL
+
filename
,
work_directory
)
logger
.
info
(
"Downloading to {}..."
.
format
(
filepath
))
download
(
url
,
work_directory
)
return
filepath
...
...
@@ -69,6 +68,9 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int.
"""
DIR_NAME
=
'mnist_data'
SOURCE_URL
=
'http://yann.lecun.com/exdb/mnist/'
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
...
...
@@ -76,15 +78,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset
"""
if
dir
is
None
:
dir
=
get_dataset_path
(
'mnist_data'
)
dir
=
get_dataset_path
(
self
.
DIR_NAME
)
assert
train_or_test
in
[
'train'
,
'test'
]
self
.
train_or_test
=
train_or_test
self
.
shuffle
=
shuffle
def
get_images_and_labels
(
image_file
,
label_file
):
f
=
maybe_download
(
image_file
,
dir
)
f
=
maybe_download
(
self
.
SOURCE_URL
+
image_file
,
dir
)
images
=
extract_images
(
f
)
f
=
maybe_download
(
label_file
,
dir
)
f
=
maybe_download
(
self
.
SOURCE_URL
+
label_file
,
dir
)
labels
=
extract_labels
(
f
)
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
]
return
images
,
labels
...
...
@@ -111,6 +113,11 @@ class Mnist(RNGDataFlow):
yield
[
img
,
label
]
class
FashionMnist
(
Mnist
):
DIR_NAME
=
'fashion_mnist_data'
SOURCE_URL
=
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
if
__name__
==
'__main__'
:
ds
=
Mnist
(
'train'
)
for
(
img
,
label
)
in
ds
.
get_data
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment