Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
7eb73782
Commit
7eb73782
authored
Mar 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
shuffle cifar/mnist data
parent
ef4a15ca
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
7 deletions
+17
-7
example_cifar10.py
example_cifar10.py
+1
-1
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+9
-3
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+7
-3
No files found.
example_cifar10.py
View file @
7eb73782
...
...
@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from
tensorpack.dataflow
import
imgaug
"""
CIFAR10 90
%
validation accuracy after 100k step
, 91
%
after 160k step.
CIFAR10 90
%
validation accuracy after 100k step
"""
BATCH_SIZE
=
128
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
7eb73782
...
...
@@ -5,6 +5,7 @@
import
os
,
sys
import
pickle
import
numpy
as
np
import
random
from
six.moves
import
urllib
,
range
import
copy
import
tarfile
...
...
@@ -65,10 +66,11 @@ class Cifar10(DataFlow):
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def
__init__
(
self
,
train_or_test
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
train_or_test: string either 'train' or 'test'
shuffle: default to True
"""
assert
train_or_test
in
[
'train'
,
'test'
]
if
dir
is
None
:
...
...
@@ -86,13 +88,17 @@ class Cifar10(DataFlow):
self
.
train_or_test
=
train_or_test
self
.
dir
=
dir
self
.
data
=
read_cifar10
(
self
.
fs
)
self
.
shuffle
=
shuffle
def
size
(
self
):
return
50000
if
self
.
train_or_test
==
'train'
else
10000
def
get_data
(
self
):
for
k
in
self
.
data
:
yield
k
idxs
=
list
(
range
(
len
(
self
.
data
)))
if
self
.
shuffle
:
random
.
shuffle
(
idxs
)
for
k
in
idxs
:
yield
self
.
data
[
k
]
def
get_per_pixel_mean
(
self
):
"""
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
7eb73782
...
...
@@ -5,7 +5,7 @@
import
os
import
gzip
import
random
import
numpy
from
six.moves
import
urllib
,
range
...
...
@@ -100,7 +100,7 @@ class Mnist(DataFlow):
Return [image, label],
image is 28x28 in the range [0,1]
"""
def
__init__
(
self
,
train_or_test
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
train_or_test: string either 'train' or 'test'
...
...
@@ -109,6 +109,7 @@ class Mnist(DataFlow):
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist_data'
)
assert
train_or_test
in
[
'train'
,
'test'
]
self
.
train_or_test
=
train_or_test
self
.
shuffle
=
shuffle
TRAIN_IMAGES
=
'train-images-idx3-ubyte.gz'
TRAIN_LABELS
=
'train-labels-idx1-ubyte.gz'
...
...
@@ -136,7 +137,10 @@ class Mnist(DataFlow):
def
get_data
(
self
):
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
for
k
in
range
(
ds
.
num_examples
):
idxs
=
list
(
range
(
ds
.
num_examples
))
if
self
.
shuffle
:
random
.
shuffle
(
idxs
)
for
k
in
idxs
:
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
label
=
ds
.
labels
[
k
]
yield
[
img
,
label
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment