Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8c57fc1f
Commit
8c57fc1f
authored
Dec 26, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
mnist standalone dataset
parent
7f3baaa1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
111 additions
and
3 deletions
+111
-3
dataflow/dataset/mnist.py
dataflow/dataset/mnist.py
+111
-3
No files found.
dataflow/dataset/mnist.py
View file @
8c57fc1f
...
...
@@ -3,12 +3,120 @@
# File: mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
os
from
tensorflow.examples.tutorials.mnist
import
input_data
import
gzip
import
numpy
from
six.moves
import
urllib
# Names exported by ``from <module> import *``.
__all__ = ['Mnist']
# Base URL to which archive file names are appended when downloading.
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here.

    Args:
        filename: name of the file to fetch; it is appended to
            ``SOURCE_URL`` to form the download address.
        work_directory: directory used as the local cache. It is created,
            together with any missing parent directories, when absent.

    Returns:
        The local path ``os.path.join(work_directory, filename)``.

    Raises:
        OSError: if the cache directory cannot be created.
        Any error raised by ``urllib.request.urlretrieve`` is propagated
        after the partially written file has been removed.
    """
    try:
        os.makedirs(work_directory)
    except OSError:
        # The directory may already exist (possibly created concurrently);
        # any other failure is a real error and is re-raised.
        if not os.path.isdir(work_directory):
            raise
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        # Fetch under a temporary name so an interrupted transfer never
        # leaves a truncated file that later runs would accept as cached.
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(SOURCE_URL + filename, partial_path)
            os.rename(partial_path, filepath)
        except BaseException:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath
def
_read32
(
bytestream
):
dt
=
numpy
.
dtype
(
numpy
.
uint32
)
.
newbyteorder
(
'>'
)
return
numpy
.
frombuffer
(
bytestream
.
read
(
4
),
dtype
=
dt
)[
0
]
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth].

    The gzip file is expected to start with four 32-bit header fields read
    via ``_read32``: the magic number 2051, the image count, the row count
    and the column count, followed by one byte per pixel.

    Args:
        filename: path of the gzip-compressed image file.

    Returns:
        A uint8 array of shape ``(num_images, rows, cols, 1)``.

    Raises:
        ValueError: if the magic number is not 2051, or if the file holds
            fewer pixel bytes than its header announces.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        # Convert to plain ints so the size product below cannot wrap
        # around in fixed-width arithmetic.
        num_images = int(_read32(bytestream))
        rows = int(_read32(bytestream))
        cols = int(_read32(bytestream))
        expected_size = rows * cols * num_images
        buf = bytestream.read(expected_size)
    if len(buf) != expected_size:
        raise ValueError(
            'Truncated MNIST image file %s: expected %d bytes, got %d' %
            (filename, expected_size, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    return data.reshape(num_images, rows, cols, 1)
def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index].

    The gzip file is expected to start with two 32-bit header fields read
    via ``_read32``: the magic number 2049 and the item count, followed by
    one byte per label.

    Args:
        filename: path of the gzip-compressed label file.

    Returns:
        A 1D uint8 array with one entry per item.

    Raises:
        ValueError: if the magic number is not 2049, or if the file holds
            fewer label bytes than its header announces.
    """
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = int(_read32(bytestream))
        buf = bytestream.read(num_items)
    # A short payload would otherwise yield a silently shortened label
    # array; fail here where the offending file is known.
    if len(buf) != num_items:
        raise ValueError(
            'Truncated MNIST label file %s: expected %d bytes, got %d' %
            (filename, num_items, len(buf)))
    return numpy.frombuffer(buf, dtype=numpy.uint8)
class DataSet(object):
    """Holds a set of images, flattened and rescaled, with their labels."""

    def __init__(self, images, labels, fake_data=False):
        """Construct a DataSet.

        Args:
            images: array indexed as [num examples, rows, columns, depth];
                depth must be 1.
            labels: array whose first dimension matches that of ``images``.
            fake_data: accepted but not used by this constructor.
        """
        example_count = images.shape[0]
        assert example_count == labels.shape[0], (
            'images.shape: %s labels.shape: %s' % (images.shape,
                                                   labels.shape))
        self._num_examples = example_count

        # Collapse [num examples, rows, columns, 1] into
        # [num examples, rows*columns].
        assert images.shape[3] == 1
        pixels_per_image = images.shape[1] * images.shape[2]
        flat = images.reshape(example_count, pixels_per_image)

        # Rescale pixel values from [0, 255] to [0.0, 1.0] as float32.
        self._images = numpy.multiply(flat.astype(numpy.float32),
                                      1.0 / 255.0)
        self._labels = labels

    @property
    def images(self):
        """The flattened, rescaled image array."""
        return self._images

    @property
    def labels(self):
        """The label array exactly as passed to the constructor."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples (first dimension of the input images)."""
        return self._num_examples
def read_data_sets(train_dir):
    """Fetch (if needed) and load the four MNIST archives.

    Each archive is obtained through ``maybe_download`` using ``train_dir``
    as the cache directory and then decoded with ``extract_images`` or
    ``extract_labels``.

    Args:
        train_dir: directory in which the archives are stored.

    Returns:
        An object with two attributes, ``train`` and ``test``, each a
        ``DataSet`` built from the corresponding images and labels.
    """
    class DataSets(object):
        pass

    def _load(archive_name, extractor):
        # Make sure the archive is present locally, then decode it.
        return extractor(maybe_download(archive_name, train_dir))

    # Same processing order as always: train images, train labels,
    # test images, test labels.
    train_images = _load('train-images-idx3-ubyte.gz', extract_images)
    train_labels = _load('train-labels-idx1-ubyte.gz', extract_labels)
    test_images = _load('t10k-images-idx3-ubyte.gz', extract_images)
    test_labels = _load('t10k-labels-idx1-ubyte.gz', extract_labels)

    data_sets = DataSets()
    data_sets.train = DataSet(train_images, train_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
class
Mnist
(
object
):
def
__init__
(
self
,
train_or_test
,
dir
=
None
):
"""
...
...
@@ -17,7 +125,7 @@ class Mnist(object):
"""
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist_data'
)
self
.
dataset
=
input_data
.
read_data_sets
(
dir
)
self
.
dataset
=
read_data_sets
(
dir
)
self
.
train_or_test
=
train_or_test
def
get_data
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment