Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0a012166
Commit
0a012166
authored
Feb 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
cifar10 mean
parent
bac11ae3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
9 deletions
+41
-9
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+39
-8
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+2
-1
No files found.
tensorpack/dataflow/dataset/cifar10.py
View file @
0a012166
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
os
,
sys
import
os
,
sys
import
pickle
import
pickle
import
numpy
import
numpy
as
np
from
six.moves
import
urllib
from
six.moves
import
urllib
import
tarfile
import
tarfile
import
logging
import
logging
...
@@ -39,6 +39,7 @@ def maybe_download_and_extract(dest_directory):
...
@@ -39,6 +39,7 @@ def maybe_download_and_extract(dest_directory):
tarfile
.
open
(
filepath
,
'r:gz'
)
.
extractall
(
dest_directory
)
tarfile
.
open
(
filepath
,
'r:gz'
)
.
extractall
(
dest_directory
)
def
read_cifar10
(
filenames
):
def
read_cifar10
(
filenames
):
ret
=
[]
for
fname
in
filenames
:
for
fname
in
filenames
:
fo
=
open
(
fname
,
'rb'
)
fo
=
open
(
fname
,
'rb'
)
dic
=
pickle
.
load
(
fo
)
dic
=
pickle
.
load
(
fo
)
...
@@ -47,8 +48,16 @@ def read_cifar10(filenames):
...
@@ -47,8 +48,16 @@ def read_cifar10(filenames):
fo
.
close
()
fo
.
close
()
for
k
in
xrange
(
10000
):
for
k
in
xrange
(
10000
):
img
=
data
[
k
]
.
reshape
(
3
,
32
,
32
)
img
=
data
[
k
]
.
reshape
(
3
,
32
,
32
)
img
=
numpy
.
transpose
(
img
,
[
1
,
2
,
0
])
img
=
np
.
transpose
(
img
,
[
1
,
2
,
0
])
yield
[
img
,
label
[
k
]]
ret
.
append
([
img
,
label
[
k
]])
return
ret
def
get_filenames
(
dir
):
filenames
=
[
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
xrange
(
1
,
6
)]
filenames
.
append
(
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
return
filenames
class
Cifar10
(
DataFlow
):
class
Cifar10
(
DataFlow
):
"""
"""
...
@@ -65,27 +74,49 @@ class Cifar10(DataFlow):
...
@@ -65,27 +74,49 @@ class Cifar10(DataFlow):
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar10_data'
)
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar10_data'
)
maybe_download_and_extract
(
dir
)
maybe_download_and_extract
(
dir
)
fnames
=
get_filenames
(
dir
)
if
train_or_test
==
'train'
:
if
train_or_test
==
'train'
:
self
.
fs
=
[
os
.
path
.
join
(
self
.
fs
=
fnames
[:
5
]
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
xrange
(
1
,
6
)]
else
:
else
:
self
.
fs
=
[
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'test_batch'
)
]
self
.
fs
=
fnames
[
-
1
]
for
f
in
self
.
fs
:
for
f
in
self
.
fs
:
if
not
os
.
path
.
isfile
(
f
):
if
not
os
.
path
.
isfile
(
f
):
raise
ValueError
(
'Failed to find file: '
+
f
)
raise
ValueError
(
'Failed to find file: '
+
f
)
self
.
train_or_test
=
train_or_test
self
.
train_or_test
=
train_or_test
self
.
dir
=
dir
self
.
data
=
read_cifar10
(
self
.
fs
)
def
size
(
self
):
def
size
(
self
):
return
50000
if
self
.
train_or_test
==
'train'
else
10000
return
50000
if
self
.
train_or_test
==
'train'
else
10000
def
get_data
(
self
):
def
get_data
(
self
):
for
k
in
read_cifar10
(
self
.
fs
)
:
for
k
in
self
.
data
:
yield
k
yield
k
def
get_per_pixel_mean
(
self
):
"""
return a mean image of all (train and test) images of size 32x32x3
"""
fnames
=
get_filenames
(
self
.
dir
)
all_imgs
=
[
x
[
0
]
for
x
in
read_cifar10
(
fnames
)]
arr
=
np
.
array
(
all_imgs
,
dtype
=
'float32'
)
mean
=
np
.
mean
(
arr
,
axis
=
0
)
return
mean
def
get_per_channel_mean
(
self
):
"""
return three values as mean of each channel
"""
mean
=
self
.
get_per_pixel_mean
()
return
np
.
mean
(
mean
,
axis
=
(
0
,
1
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
ds
=
Cifar10
(
'train'
)
ds
=
Cifar10
(
'train'
)
from
dataflow.dftools
import
dump_dataset_images
from
tensorpack.dataflow.dftools
import
dump_dataset_images
mean
=
ds
.
get_per_channel_mean
()
print
mean
dump_dataset_images
(
ds
,
'/tmp/cifar'
,
100
)
dump_dataset_images
(
ds
,
'/tmp/cifar'
,
100
)
#for (img, label) in ds.get_data():
#for (img, label) in ds.get_data():
#from IPython import embed; embed()
#from IPython import embed; embed()
#break
#break
...
...
tensorpack/dataflow/dftools.py
View file @
0a012166
...
@@ -6,8 +6,9 @@
...
@@ -6,8 +6,9 @@
import
sys
,
os
import
sys
,
os
from
scipy.misc
import
imsave
from
scipy.misc
import
imsave
from
utils.utils
import
mkdir_p
from
..
utils.utils
import
mkdir_p
# TODO name_func to write label?
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
""" dump images to a folder
""" dump images to a folder
index: the index of the image in a data point
index: the index of the image in a data point
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment