Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6ef876c9
Commit
6ef876c9
authored
Oct 05, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simplify mnist dataset code
parent
26f09ada
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
57 deletions
+21
-57
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+21
-57
No files found.
tensorpack/dataflow/dataset/mnist.py
View file @
6ef876c9
...
@@ -45,9 +45,9 @@ def extract_images(filename):
...
@@ -45,9 +45,9 @@ def extract_images(filename):
buf
=
bytestream
.
read
(
rows
*
cols
*
num_images
)
buf
=
bytestream
.
read
(
rows
*
cols
*
num_images
)
data
=
numpy
.
frombuffer
(
buf
,
dtype
=
numpy
.
uint8
)
data
=
numpy
.
frombuffer
(
buf
,
dtype
=
numpy
.
uint8
)
data
=
data
.
reshape
(
num_images
,
rows
,
cols
,
1
)
data
=
data
.
reshape
(
num_images
,
rows
,
cols
,
1
)
data
=
data
.
astype
(
'float32'
)
/
255.0
return
data
return
data
def
extract_labels
(
filename
):
def
extract_labels
(
filename
):
"""Extract the labels into a 1D uint8 numpy array [index]."""
"""Extract the labels into a 1D uint8 numpy array [index]."""
with
gzip
.
open
(
filename
)
as
bytestream
:
with
gzip
.
open
(
filename
)
as
bytestream
:
...
@@ -61,37 +61,6 @@ def extract_labels(filename):
...
@@ -61,37 +61,6 @@ def extract_labels(filename):
labels
=
numpy
.
frombuffer
(
buf
,
dtype
=
numpy
.
uint8
)
labels
=
numpy
.
frombuffer
(
buf
,
dtype
=
numpy
.
uint8
)
return
labels
return
labels
class
DataSet
(
object
):
def
__init__
(
self
,
images
,
labels
,
fake_data
=
False
):
"""Construct a DataSet. """
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
],
(
'images.shape:
%
s labels.shape:
%
s'
%
(
images
.
shape
,
labels
.
shape
))
self
.
_num_examples
=
images
.
shape
[
0
]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
assert
images
.
shape
[
3
]
==
1
images
=
images
.
reshape
(
images
.
shape
[
0
],
images
.
shape
[
1
]
*
images
.
shape
[
2
])
# Convert from [0, 255] -> [0.0, 1.0].
images
=
images
.
astype
(
numpy
.
float32
)
images
=
numpy
.
multiply
(
images
,
1.0
/
255.0
)
self
.
_images
=
images
self
.
_labels
=
labels
@
property
def
images
(
self
):
return
self
.
_images
@
property
def
labels
(
self
):
return
self
.
_labels
@
property
def
num_examples
(
self
):
return
self
.
_num_examples
class
Mnist
(
RNGDataFlow
):
class
Mnist
(
RNGDataFlow
):
"""
"""
Return [image, label],
Return [image, label],
...
@@ -108,38 +77,33 @@ class Mnist(RNGDataFlow):
...
@@ -108,38 +77,33 @@ class Mnist(RNGDataFlow):
self
.
train_or_test
=
train_or_test
self
.
train_or_test
=
train_or_test
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
TRAIN_IMAGES
=
'train-images-idx3-ubyte.gz'
def
get_images_and_labels
(
image_file
,
label_file
):
TRAIN_LABELS
=
'train-labels-idx1-ubyte.gz'
f
=
maybe_download
(
image_file
,
dir
)
TEST_IMAGES
=
't10k-images-idx3-ubyte.gz'
images
=
extract_images
(
f
)
TEST_LABELS
=
't10k-labels-idx1-ubyte.gz'
f
=
maybe_download
(
label_file
,
dir
)
labels
=
extract_labels
(
f
)
local_file
=
maybe_download
(
TRAIN_IMAGES
,
dir
)
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
]
train_images
=
extract_images
(
local_file
)
return
images
,
labels
local_file
=
maybe_download
(
TRAIN_LABELS
,
dir
)
if
self
.
train_or_test
==
'train'
:
train_labels
=
extract_labels
(
local_file
)
self
.
images
,
self
.
labels
=
get_images_and_labels
(
'train-images-idx3-ubyte.gz'
,
local_file
=
maybe_download
(
TEST_IMAGES
,
dir
)
'train-labels-idx1-ubyte.gz'
)
test_images
=
extract_images
(
local_file
)
else
:
self
.
images
,
self
.
labels
=
get_images_and_labels
(
local_file
=
maybe_download
(
TEST_LABELS
,
dir
)
't10k-images-idx3-ubyte.gz'
,
test_labels
=
extract_labels
(
local_file
)
't10k-labels-idx1-ubyte.gz'
)
self
.
train
=
DataSet
(
train_images
,
train_labels
)
self
.
test
=
DataSet
(
test_images
,
test_labels
)
def
size
(
self
):
def
size
(
self
):
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
return
self
.
images
.
shape
[
0
]
return
ds
.
num_examples
def
get_data
(
self
):
def
get_data
(
self
):
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
idxs
=
list
(
range
(
self
.
size
()))
idxs
=
list
(
range
(
ds
.
num_examples
))
if
self
.
shuffle
:
if
self
.
shuffle
:
self
.
rng
.
shuffle
(
idxs
)
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
for
k
in
idxs
:
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
img
=
self
.
images
[
k
]
.
reshape
((
28
,
28
))
label
=
ds
.
labels
[
k
]
label
=
self
.
labels
[
k
]
yield
[
img
,
label
]
yield
[
img
,
label
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment