Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0ba0336c
Commit
0ba0336c
authored
Feb 12, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
prefetch with multiprocessing
parent
f7af025e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
52 additions
and
31 deletions
+52
-31
example_cifar10.py
example_cifar10.py
+1
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+17
-5
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+31
-23
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+2
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
No files found.
example_cifar10.py
View file @
0ba0336c
...
...
@@ -20,7 +20,7 @@ from tensorpack.dataflow import imgaug
"""
This config follows the same preprocessing/model/hyperparameters as in
tensorflow cifar10 examples. (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/)
But it's
faster.
86
%
accuracy.
faster.
"""
BATCH_SIZE
=
128
...
...
tensorpack/dataflow/common.py
View file @
0ba0336c
...
...
@@ -63,18 +63,26 @@ class FixedSizeData(DataFlow):
def
__init__
(
self
,
ds
,
size
):
self
.
ds
=
ds
self
.
_size
=
size
self
.
itr
=
None
def
size
(
self
):
return
self
.
_size
def
get_data
(
self
):
if
self
.
itr
is
None
:
self
.
itr
=
self
.
ds
.
get_data
()
cnt
=
0
while
True
:
for
dp
in
self
.
ds
.
get_data
():
cnt
+=
1
yield
dp
if
cnt
==
self
.
_size
:
return
try
:
dp
=
self
.
itr
.
next
()
except
StopIteration
:
self
.
itr
=
self
.
ds
.
get_data
()
dp
=
self
.
itr
.
next
()
cnt
+=
1
yield
dp
if
cnt
==
self
.
_size
:
return
class
RepeatedData
(
DataFlow
):
""" repeat another dataflow for certain times"""
...
...
@@ -93,6 +101,9 @@ class RepeatedData(DataFlow):
class
FakeData
(
DataFlow
):
""" Build fake random data of given shapes"""
def
__init__
(
self
,
shapes
,
size
):
"""
shapes: list of list/tuple
"""
self
.
shapes
=
shapes
self
.
_size
=
size
...
...
@@ -126,6 +137,7 @@ def AugmentImageComponent(ds, augmentors, index=0):
augmentors: a list of ImageAugmentor instance
index: the index of image in each data point. default to be 0
"""
# TODO reset rng at the beginning of each get_data
aug
=
AugmentorList
(
augmentors
)
return
MapData
(
ds
,
...
...
tensorpack/dataflow/prefetch.py
View file @
0ba0336c
...
...
@@ -13,40 +13,48 @@ class Sentinel:
pass
class
PrefetchProcess
(
multiprocessing
.
Process
):
def
__init__
(
self
,
ds
,
queue_size
):
def
__init__
(
self
,
ds
,
queue
):
"""
ds: ds to take data from
queue: output queue to put results in
"""
super
(
PrefetchProcess
,
self
)
.
__init__
()
self
.
ds
=
ds
self
.
queue
=
multiprocessing
.
Queue
(
queue_size
)
self
.
queue
=
queue
def
run
(
self
):
for
dp
in
self
.
ds
.
get_data
():
self
.
queue
.
put
(
dp
)
self
.
queue
.
put
(
Sentinel
())
try
:
for
dp
in
self
.
ds
.
get_data
():
self
.
queue
.
put
(
dp
)
finally
:
self
.
queue
.
put
(
Sentinel
())
def
get_data
(
self
):
while
True
:
ret
=
self
.
queue
.
get
()
if
isinstance
(
ret
,
Sentinel
):
return
yield
ret
class
PrefetchData
(
DataFlow
):
def
__init__
(
self
,
ds
,
nr_prefetch
):
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
"""
use multiprocess, will duplicate ds by nr_proc times
"""
self
.
ds
=
ds
self
.
nr_prefetch
=
int
(
nr_prefetch
)
assert
self
.
nr_prefetch
>
0
def
size
(
self
):
return
self
.
ds
.
size
()
self
.
nr_proc
=
nr_proc
self
.
nr_prefetch
=
nr_prefetch
def
get_data
(
self
):
worker
=
PrefetchProcess
(
self
.
ds
,
self
.
nr_prefetch
)
# TODO register terminate function
worker
.
start
()
queue
=
multiprocessing
.
Queue
(
self
.
nr_prefetch
)
procs
=
[
PrefetchProcess
(
self
.
ds
,
queue
)
for
_
in
range
(
self
.
nr_proc
)]
[
x
.
start
()
for
x
in
procs
]
end_cnt
=
0
try
:
for
dp
in
worker
.
get_data
():
while
True
:
dp
=
queue
.
get
()
if
isinstance
(
dp
,
Sentinel
):
end_cnt
+=
1
if
end_cnt
==
self
.
nr_proc
:
break
continue
yield
dp
finally
:
worker
.
join
()
worker
.
terminate
()
queue
.
close
()
[
x
.
terminate
()
for
x
in
procs
]
tensorpack/models/batch_norm.py
View file @
0ba0336c
...
...
@@ -11,10 +11,11 @@ __all__ = ['BatchNorm']
# http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
# Only work for 4D tensor right now: #804
@
layer_register
()
def
BatchNorm
(
x
,
is_training
):
"""
x:
has to be BHWC for now
x:
BHWC tensor
is_training: bool
"""
is_training
=
bool
(
is_training
)
...
...
tensorpack/models/conv2d.py
View file @
0ba0336c
...
...
@@ -11,7 +11,7 @@ __all__ = ['Conv2D']
@
layer_register
(
summary_activation
=
True
)
def
Conv2D
(
x
,
out_channel
,
kernel_shape
,
padding
=
'
VALID
'
,
stride
=
1
,
padding
=
'
SAME
'
,
stride
=
1
,
W_init
=
None
,
b_init
=
None
,
nl
=
tf
.
nn
.
relu
,
split
=
1
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment