Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
68e8d9eb
Commit
68e8d9eb
authored
Dec 09, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
MultiProcessMapData with strict (#414)
parent
be3a07a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
99 additions
and
75 deletions
+99
-75
.github/ISSUE_TEMPLATE.md
.github/ISSUE_TEMPLATE.md
+2
-8
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+97
-67
No files found.
.github/ISSUE_TEMPLATE.md
View file @
68e8d9eb
Bug Reports/Feature Requests/Usage Questions Only:
Bug
Reports
: PLEASE always include
Bug
reports or other problems with code
: PLEASE always include
1.
What you did. (command you run and changes you made if using examples; post or describe your code if not)
2.
What you observed, e.g.
logs
.
2.
What you observed, e.g.
as much as logs possible
.
3.
What you expected, if not obvious.
4.
Your environment (TF version, cudnn version, number & type of GPUs), if it matters.
5.
About efficiency, PLEASE first read http://tensorpack.readthedocs.io/en/latest/tutorial/performance-tuning.html
...
...
@@ -14,10 +14,4 @@ Feature Requests:
It may not have to be added to tensorpack unless you have a good reason.
3.
Note that we don't implement papers at others' requests.
Usage Questions, e.g.:
"How do I do [this specific thing] in tensorpack?"
"Why certain examples need to be written in this way?"
We don't answer general machine learning questions like:
"I want to do [this machine learning task]. What specific things do I need to do?"
You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
tensorpack/dataflow/prefetch.py
View file @
68e8d9eb
...
...
@@ -67,20 +67,11 @@ def _zmq_catch_error(name):
class
_MultiProcessZMQDataFlow
(
DataFlow
):
def
__init__
(
self
,
ds
):
def
__init__
(
self
):
assert
os
.
name
!=
'nt'
,
"ZMQ IPC doesn't support windows!"
self
.
_reset_done
=
False
self
.
_procs
=
[]
self
.
ds
=
ds
try
:
self
.
_size
=
ds
.
size
()
except
NotImplementedError
:
self
.
_size
=
-
1
def
size
(
self
):
return
self
.
ds
.
size
()
def
reset_state
(
self
):
"""
All forked dataflows are reset **once and only once** in spawned processes.
...
...
@@ -265,10 +256,17 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
if
nr_proc
>
1
:
logger
.
info
(
"[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d."
)
try
:
self
.
_size
=
ds
.
size
()
except
NotImplementedError
:
self
.
_size
=
-
1
def
_recv
(
self
):
return
loads
(
self
.
socket
.
recv
(
copy
=
False
)
.
bytes
)
def
size
(
self
):
return
self
.
ds
.
size
()
def
get_data
(
self
):
with
self
.
_guard
,
_zmq_catch_error
(
'PrefetchDataZMQ'
):
for
k
in
itertools
.
count
():
...
...
@@ -311,7 +309,59 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
proc
.
start
()
class
MultiThreadMapData
(
ProxyDataFlow
):
class
_ParallelMapData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
buffer_size
):
super
(
_ParallelMapData
,
self
)
.
__init__
(
ds
)
assert
buffer_size
>
0
,
buffer_size
self
.
_buffer_size
=
buffer_size
def
_recv
(
self
):
pass
def
_send
(
self
,
dp
):
pass
def
_recv_filter_none
(
self
):
ret
=
self
.
_recv
()
assert
ret
is
not
None
,
\
"[{}] Map function cannot return None when strict mode is used."
.
format
(
type
(
self
)
.
__name__
)
return
ret
def
_fill_buffer
(
self
):
try
:
for
_
in
range
(
self
.
_buffer_size
):
dp
=
next
(
self
.
_iter
)
self
.
_send
(
dp
)
except
StopIteration
:
logger
.
error
(
"[{}] buffer_size cannot be larger than the size of the DataFlow!"
.
format
(
type
(
self
)
.
__name__
))
raise
def
get_data_non_strict
(
self
):
for
dp
in
self
.
_iter
:
self
.
_send
(
dp
)
yield
self
.
_recv
()
self
.
_iter
=
self
.
ds
.
get_data
()
# refresh
for
_
in
range
(
self
.
_buffer_size
):
self
.
_send
(
next
(
self
.
_iter
))
yield
self
.
_recv
()
def
get_data_strict
(
self
):
for
dp
in
self
.
_iter
:
self
.
_send
(
dp
)
yield
self
.
_recv_filter_none
()
self
.
_iter
=
self
.
ds
.
get_data
()
# refresh
# first clear the buffer, then fill
for
k
in
range
(
self
.
_buffer_size
):
dp
=
self
.
_recv_filter_none
()
if
k
==
self
.
_buffer_size
-
1
:
self
.
_fill_buffer
()
yield
dp
class
MultiThreadMapData
(
_ParallelMapData
):
"""
Same as :class:`MapData`, but start threads to run the mapping function.
This is useful when the mapping function is the bottleneck, but you don't
...
...
@@ -367,11 +417,10 @@ class MultiThreadMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super
(
MultiThreadMapData
,
self
)
.
__init__
(
ds
)
super
(
MultiThreadMapData
,
self
)
.
__init__
(
ds
,
buffer_size
)
self
.
_strict
=
strict
self
.
nr_thread
=
nr_thread
self
.
buffer_size
=
buffer_size
self
.
map_func
=
map_func
self
.
_threads
=
[]
self
.
_evt
=
None
...
...
@@ -398,43 +447,20 @@ class MultiThreadMapData(ProxyDataFlow):
# only call once, to ensure inq+outq has a total of buffer_size elements
self
.
_fill_buffer
()
def
_fill_buffer
(
self
):
n
=
self
.
buffer_size
-
self
.
_in_queue
.
qsize
()
-
self
.
_out_queue
.
qsize
()
assert
n
>=
0
,
n
if
n
==
0
:
return
try
:
for
_
in
range
(
n
):
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
except
StopIteration
:
logger
.
error
(
"[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!"
)
raise
def
_recv
(
self
):
ret
=
self
.
_out_queue
.
get
()
if
ret
is
None
:
assert
not
self
.
_strict
,
\
"[MultiThreadMapData] Map function cannot return None when strict mode is used."
return
ret
return
self
.
_out_queue
.
get
()
def
_send
(
self
,
dp
):
self
.
_in_queue
.
put
(
dp
)
def
get_data
(
self
):
with
self
.
_guard
:
for
dp
in
self
.
_iter
:
self
.
_in_queue
.
put
(
dp
)
yield
self
.
_recv
()
self
.
_iter
=
self
.
ds
.
get_data
()
if
self
.
_strict
:
# first call get() to clear the queues, then fill
for
k
in
range
(
self
.
buffer_size
):
dp
=
self
.
_recv
()
if
k
==
self
.
buffer_size
-
1
:
self
.
_fill_buffer
()
for
dp
in
self
.
get_data_strict
():
yield
dp
else
:
for
_
in
range
(
self
.
buffer_size
):
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
yield
self
.
_recv
()
for
dp
in
self
.
get_data_non_strict
():
yield
dp
def
__del__
(
self
):
if
self
.
_evt
is
not
None
:
...
...
@@ -447,10 +473,20 @@ class MultiThreadMapData(ProxyDataFlow):
ThreadedMapData
=
MultiThreadMapData
class
MultiProcessMapDataZMQ
(
_MultiProcessZMQDataFlow
):
class
MultiProcessMapDataZMQ
(
_
ParallelMapData
,
_
MultiProcessZMQDataFlow
):
"""
Same as :class:`MapData`, but start processes to run the mapping function,
and communicate with ZeroMQ pipe.
Note:
1. Processes run in parallel and can take different time to run the
mapping function. Therefore the order of datapoints won't be
preserved, and datapoints from one pass of `df.get_data()` might get
mixed with datapoints from the next pass.
You can use **strict mode**, where `MultiProcessMapData.get_data()`
is guranteed to produce the exact set which `df.get_data()`
produces. Although the order of data still isn't preserved.
"""
class
_Worker
(
mp
.
Process
):
def
__init__
(
self
,
identity
,
map_func
,
pipename
,
hwm
):
...
...
@@ -472,30 +508,32 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
dp
=
self
.
map_func
(
dp
)
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
def
__init__
(
self
,
ds
,
nr_proc
,
map_func
,
buffer_size
=
200
):
def
__init__
(
self
,
ds
,
nr_proc
,
map_func
,
buffer_size
=
200
,
strict
=
False
):
"""
Args:
ds (DataFlow): the dataflow to map
nr_proc(int): number of threads to use
map_func (callable): datapoint -> datapoint | None
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super
(
MultiProcessMapDataZMQ
,
self
)
.
__init__
(
ds
)
_ParallelMapData
.
__init__
(
self
,
ds
,
buffer_size
)
_MultiProcessZMQDataFlow
.
__init__
(
self
)
self
.
nr_proc
=
nr_proc
self
.
map_func
=
map_func
self
.
buffer_size
=
buffer_size
self
.
_strict
=
strict
self
.
_procs
=
[]
self
.
_guard
=
DataFlowReentrantGuard
()
def
_reset_once
(
self
):
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
socket
.
set_hwm
(
self
.
buffer_size
*
2
)
self
.
socket
.
set_hwm
(
self
.
_
buffer_size
*
2
)
pipename
=
_get_pipe_name
(
'dataflow-map'
)
_bind_guard
(
self
.
socket
,
pipename
)
self
.
_proc_ids
=
[
u'{}'
.
format
(
k
)
.
encode
(
'utf-8'
)
for
k
in
range
(
self
.
nr_proc
)]
worker_hwm
=
int
(
self
.
buffer_size
*
2
//
self
.
nr_proc
)
worker_hwm
=
int
(
self
.
_
buffer_size
*
2
//
self
.
nr_proc
)
self
.
_procs
=
[
MultiProcessMapDataZMQ
.
_Worker
(
self
.
_proc_ids
[
k
],
self
.
map_func
,
pipename
,
worker_hwm
)
for
k
in
range
(
self
.
nr_proc
)]
...
...
@@ -507,14 +545,8 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
self
.
_start_processes
()
self
.
_fill_buffer
()
def
_fill_buffer
(
self
):
# Filling the buffer.
try
:
for
_
in
range
(
self
.
buffer_size
):
self
.
_send
(
next
(
self
.
_iter
))
except
StopIteration
:
logger
.
error
(
"[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!"
)
raise
def
reset_state
(
self
):
_MultiProcessZMQDataFlow
.
reset_state
(
self
)
def
_send
(
self
,
dp
):
# round-robin assignment
...
...
@@ -529,14 +561,12 @@ class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
def
get_data
(
self
):
with
self
.
_guard
,
_zmq_catch_error
(
'MultiProcessMapData'
):
for
dp
in
self
.
_iter
:
self
.
_send
(
dp
)
yield
self
.
_recv
()
self
.
_iter
=
self
.
ds
.
get_data
()
# refresh
for
_
in
range
(
self
.
buffer_size
):
self
.
_send
(
next
(
self
.
_iter
))
yield
self
.
_recv
()
if
self
.
_strict
:
for
dp
in
self
.
get_data_strict
():
yield
dp
else
:
for
dp
in
self
.
get_data_non_strict
():
yield
dp
MultiProcessMapData
=
MultiProcessMapDataZMQ
# alias
...
...
@@ -549,13 +579,13 @@ if __name__ == '__main__':
def
get_data
(
self
):
for
k
in
range
(
self
.
_size
):
yield
[
0
]
yield
[
k
]
def
size
(
self
):
return
self
.
_size
ds
=
Zero
(
300
)
ds
=
MultiProcessMapData
(
ds
,
3
,
lambda
x
:
[
x
[
0
]
+
1
])
ds
=
MultiProcessMapData
(
ds
,
3
,
lambda
x
:
[
x
[
0
]
+
1
]
,
strict
=
True
)
ds
.
reset_state
()
for
k
in
ds
.
get_data
():
print
(
"Bang!"
,
k
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment