Shashank Suhas / seminar-breakout / Commits / fa8af3d8

Commit fa8af3d8, authored Oct 13, 2017 by Yuxin Wu

refactor ZMQ dataflow (#414)

parent 6a4f4ee2

Showing 1 changed file with 146 additions and 144 deletions:

tensorpack/dataflow/prefetch.py (+146, -144)
...
@@ -13,7 +13,7 @@ import uuid
 import os
 import zmq
-from .base import ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
+from .base import DataFlow, ProxyDataFlow, DataFlowTerminated, DataFlowReentrantGuard
 from ..utils.concurrency import (ensure_proc_terminate,
                                  mask_sigint, start_proc_mask_signal, StoppableThread)
...
@@ -65,24 +65,53 @@ def _zmq_catch_error(name):
         raise


-class PrefetchProcess(mp.Process):
-    def __init__(self, ds, queue, reset_after_spawn=True):
-        """
-        :param ds: ds to take data from
-        :param queue: output queue to put results in
-        """
-        super(PrefetchProcess, self).__init__()
-        self.ds = ds
-        self.queue = queue
-        self.reset_after_spawn = reset_after_spawn
-
-    def run(self):
-        if self.reset_after_spawn:
-            # reset all ds so each process will produce different data
-            self.ds.reset_state()
-        while True:
-            for dp in self.ds.get_data():
-                self.queue.put(dp)
+class _MultiProcessZMQDataFlow(DataFlow):
+    def __init__(self, ds):
+        assert os.name != 'nt', "ZMQ IPC doesn't support windows!"
+        self._reset_done = False
+        self._procs = []
+        self.ds = ds
+
+        try:
+            self._size = ds.size()
+        except NotImplementedError:
+            self._size = -1
+
+    def size(self):
+        return self.ds.size()
+
+    def reset_state(self):
+        """
+        All forked dataflows are reset **once and only once** in spawned processes.
+        Nothing more can be done when calling this method.
+        """
+        if self._reset_done:
+            return
+        self._reset_done = True
+
+        # __del__ not guaranteed to get called at exit
+        import atexit
+        atexit.register(lambda x: x.__del__(), self)
+
+        self._reset_once()  # build processes
+
+    def _reset_once(self):
+        pass
+
+    def _start_processes(self):
+        start_proc_mask_signal(self._procs)
+
+    def __del__(self):
+        if not self._reset_done:
+            return
+        if not self.context.closed:
+            self.context.destroy(0)
+        for x in self._procs:
+            x.terminate()
+        try:
+            print("{} successfully cleaned-up.".format(type(self).__name__))
+        except:
+            pass


 class PrefetchData(ProxyDataFlow):
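The new base class registers its own __del__ handler with atexit because, as the comment in reset_state() notes, finalizers are not guaranteed to run at interpreter exit. A minimal standalone sketch of that pattern (the Cleaner class here is hypothetical, for illustration only):

import atexit

class Cleaner(object):
    def __del__(self):
        # Idempotent: may run via atexit, via garbage collection, or both,
        # which is why _MultiProcessZMQDataFlow guards with _reset_done.
        if getattr(self, '_cleaned', False):
            return
        self._cleaned = True
        print("cleaned up")

c = Cleaner()
atexit.register(lambda x: x.__del__(), c)  # same trick as reset_state() above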
...
@@ -102,6 +131,20 @@ class PrefetchData(ProxyDataFlow):
            This is different from the behavior of :class:`PrefetchDataZMQ`
         4. `reset_state()` is a no-op. The worker processes won't get called.
     """
+
+    class _Worker(mp.Process):
+        def __init__(self, ds, queue):
+            super(PrefetchData._Worker, self).__init__()
+            self.ds = ds
+            self.queue = queue
+
+        def run(self):
+            # reset all ds so each process will produce different data
+            self.ds.reset_state()
+            while True:
+                for dp in self.ds.get_data():
+                    self.queue.put(dp)
+
     def __init__(self, ds, nr_prefetch, nr_proc):
         """
         Args:
...
@@ -119,7 +162,7 @@ class PrefetchData(ProxyDataFlow):
         self._guard = DataFlowReentrantGuard()
         self.queue = mp.Queue(self.nr_prefetch)
-        self.procs = [PrefetchProcess(self.ds, self.queue)
+        self.procs = [PrefetchData._Worker(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
         ensure_proc_terminate(self.procs)
         start_proc_mask_signal(self.procs)
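For context, a hedged usage sketch of PrefetchData, following the constructor shown in this diff (MyDataFlow is a stand-in, not part of this commit):

ds = MyDataFlow()                                  # placeholder dataflow
ds = PrefetchData(ds, nr_prefetch=256, nr_proc=4)  # 4 forked workers feeding a queue of 256
ds.reset_state()   # a no-op for PrefetchData, per item 4 of the docstring above
for dp in ds.get_data():
    pass           # consume prefetched datapoints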
...
@@ -137,34 +180,12 @@ class PrefetchData(ProxyDataFlow):
             pass


-class PrefetchProcessZMQ(mp.Process):
-    def __init__(self, ds, conn_name, hwm):
-        super(PrefetchProcessZMQ, self).__init__()
-        self.ds = ds
-        self.conn_name = conn_name
-        self.hwm = hwm
-
-    def run(self):
-        self.ds.reset_state()
-        context = zmq.Context()
-        socket = context.socket(zmq.PUSH)
-        socket.set_hwm(self.hwm)
-        socket.connect(self.conn_name)
-        try:
-            while True:
-                for dp in self.ds.get_data():
-                    socket.send(dumps(dp), copy=False)
-        # sigint could still propagate here, e.g. when nested
-        except KeyboardInterrupt:
-            pass
-
-
-class PrefetchDataZMQ(ProxyDataFlow):
+class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
     """
     Prefetch data from a DataFlow using multiple processes, with ZeroMQ for
     communication.
-    It will fork the process calling :meth:`reset_state()`,
-    collect datapoints from `ds` in each process by ZeroMQ IPC pipe.
+    It will fork the calling process of :meth:`reset_state()`,
+    and collect datapoints from `ds` in each process by ZeroMQ IPC pipe.

     Note:
         1. An iterator cannot run faster automatically -- what's happening is
...
@@ -194,6 +215,28 @@ class PrefetchDataZMQ(ProxyDataFlow):
            which points to a local directory.
         5. Calling `reset_state()` more than once is a no-op, i.e. the worker processes won't get called.
     """
+
+    class _Worker(mp.Process):
+        def __init__(self, ds, conn_name, hwm):
+            super(PrefetchDataZMQ._Worker, self).__init__()
+            self.ds = ds
+            self.conn_name = conn_name
+            self.hwm = hwm
+
+        def run(self):
+            self.ds.reset_state()
+            context = zmq.Context()
+            socket = context.socket(zmq.PUSH)
+            socket.set_hwm(self.hwm)
+            socket.connect(self.conn_name)
+            try:
+                while True:
+                    for dp in self.ds.get_data():
+                        socket.send(dumps(dp), copy=False)
+            # sigint could still propagate here, e.g. when nested
+            except KeyboardInterrupt:
+                pass
+
     def __init__(self, ds, nr_proc=1, hwm=50):
         """
         Args:
...
@@ -201,12 +244,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
             nr_proc (int): number of processes to use.
             hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
         """
-        assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! PrefetchData might work sometimes."
         super(PrefetchDataZMQ, self).__init__(ds)
-        try:
-            self._size = ds.size()
-        except NotImplementedError:
-            self._size = -1
         self.nr_proc = nr_proc
         self._hwm = hwm
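The _Worker above is the PUSH half of a ZeroMQ pipeline; the parent binds the PULL half in _reset_once() further down. A self-contained sketch of the same pattern, independent of tensorpack (the pipe name is arbitrary; IPC transports do not work on Windows, hence the assert in the base class):

import zmq

ctx = zmq.Context()
pull = ctx.socket(zmq.PULL)          # consumer side binds, like the parent process
pull.set_hwm(50)
pull.bind("ipc:///tmp/demo-pipe")

push = ctx.socket(zmq.PUSH)          # producer side connects, like each worker
push.set_hwm(50)
push.connect("ipc:///tmp/demo-pipe")

push.send(b"datapoint")
assert pull.recv() == b"datapoint"
ctx.destroy(0)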
...
@@ -223,42 +262,16 @@ class PrefetchDataZMQ(ProxyDataFlow):
                     break
                 yield self._recv()

-    def reset_state(self):
-        """
-        All forked dataflows are reset **once and only once** in spawned processes.
-        Nothing more can be done when calling this method.
-        """
-        if self._reset_done:
-            return
-        self._reset_done = True
-
+    def _reset_once(self):
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
         self.socket.set_hwm(self._hwm)
         pipename = _get_pipe_name('dataflow')
         _bind_guard(self.socket, pipename)

-        self.procs = [PrefetchProcessZMQ(self.ds, pipename, self._hwm)
-                      for _ in range(self.nr_proc)]
+        self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm)
+                       for _ in range(self.nr_proc)]
         self._start_processes()

-        # __del__ not guranteed to get called at exit
-        import atexit
-        atexit.register(lambda x: x.__del__(), self)
-
-    def _start_processes(self):
-        start_proc_mask_signal(self.procs)
-
-    def __del__(self):
-        if not self._reset_done:
-            return
-        if not self.context.closed:
-            self.context.destroy(0)
-        for x in self.procs:
-            x.terminate()
-        try:
-            print("PrefetchDataZMQ successfully cleaned-up.")
-        except:
-            pass


 class PrefetchOnGPUs(PrefetchDataZMQ):
...
@@ -279,7 +292,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
     def _start_processes(self):
         with mask_sigint():
-            for gpu, proc in zip(self.gpus, self.procs):
+            for gpu, proc in zip(self.gpus, self._procs):
                 with change_gpu(gpu):
                     proc.start()
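Each worker is started under change_gpu(gpu), so the forked child inherits a per-worker GPU binding. Tensorpack provides its own change_gpu helper; the re-implementation below is an assumption for illustration of what such a context manager looks like:

import os
from contextlib import contextmanager

@contextmanager
def change_gpu(gpu):
    # Child processes forked inside this block inherit the environment,
    # which is how each prefetch worker ends up pinned to one GPU.
    old = os.environ.get('CUDA_VISIBLE_DEVICES')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    try:
        yield
    finally:
        if old is None:
            del os.environ['CUDA_VISIBLE_DEVICES']
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = old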
...
@@ -307,14 +320,13 @@ class MultiThreadMapData(ProxyDataFlow):
     is guaranteed to produce the exact set which `df.get_data()`
     produces. Although the order of data still isn't preserved.
     """
-    class _WorkerThread(StoppableThread):
-        def __init__(self, inq, outq, evt, map_func, strict):
-            super(MultiThreadMapData._WorkerThread, self).__init__(evt)
+    class _Worker(StoppableThread):
+        def __init__(self, inq, outq, evt, map_func):
+            super(MultiThreadMapData._Worker, self).__init__(evt)
             self.inq = inq
             self.outq = outq
             self.func = map_func
             self.daemon = True
-            self._strict = strict

         def run(self):
             try:
...
@@ -322,13 +334,8 @@ class MultiThreadMapData(ProxyDataFlow):
                 dp = self.queue_get_stoppable(self.inq)
                 if self.stopped():
                     return
-                dp = self.func(dp)
-                if dp is not None:
-                    self.outq.put(dp)
-                else:
-                    assert not self._strict, \
-                        "[MultiThreadMapData] Map function cannot return None when strict mode is used."
+                # cannot ignore None here. will lead to unsynced send/recv
+                self.outq.put(self.func(dp))
             except:
                 if self.stopped():
                     pass        # skip duplicated error messages
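The refactored worker forwards self.func(dp) even when it is None: every datapoint put into inq must produce exactly one item in outq, otherwise the blocking outq.get() in the consumer would wait for an item that never arrives. The strict-mode check moves to the consumer side, in the _recv() method added below. A toy illustration of the one-in/one-out invariant (plain queues, no threads, hypothetical map function):

from queue import Queue

inq, outq = Queue(), Queue()
func = lambda x: None if x % 2 else x   # a map function that sometimes returns None

for dp in range(4):
    inq.put(dp)
    outq.put(func(inq.get()))  # the worker must answer every request, even with None

assert outq.qsize() == 4       # dropping the Nones would leave only 2 answers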
...
@@ -348,7 +355,6 @@ class MultiThreadMapData(ProxyDataFlow):
         """
         super(MultiThreadMapData, self).__init__(ds)

-        self._iter_ds = ds
         self._strict = strict
         self.nr_thread = nr_thread
         self.buffer_size = buffer_size
...
@@ -366,13 +372,13 @@ class MultiThreadMapData(ProxyDataFlow):
         self._in_queue = queue.Queue()
         self._out_queue = queue.Queue()
         self._evt = threading.Event()
-        self._threads = [MultiThreadMapData._WorkerThread(
-            self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
+        self._threads = [MultiThreadMapData._Worker(
+            self._in_queue, self._out_queue, self._evt, self.map_func)
             for _ in range(self.nr_thread)]
         for t in self._threads:
             t.start()

-        self._iter = self._iter_ds.get_data()
+        self._iter = self.ds.get_data()
         self._guard = DataFlowReentrantGuard()

         # only call once, to ensure inq+outq has a total of buffer_size elements
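A hedged usage sketch of MultiThreadMapData; the constructor signature is inferred from the attributes used in this diff (nr_thread, map_func, buffer_size, strict) and may not match the full API:

ds = MyDataFlow()   # placeholder dataflow
ds = MultiThreadMapData(ds, nr_thread=8, map_func=lambda dp: dp, buffer_size=200)
ds.reset_state()    # starts 8 daemon threads and pre-fills the buffer once
for dp in ds.get_data():
    pass            # mapped datapoints; order is not preserved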
...
@@ -390,24 +396,31 @@ class MultiThreadMapData(ProxyDataFlow):
             logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
             raise

+    def _recv(self):
+        ret = self._out_queue.get()
+        if ret is None:
+            assert not self._strict, \
+                "[MultiThreadMapData] Map function cannot return None when strict mode is used."
+        return ret
+
     def get_data(self):
         with self._guard:
             for dp in self._iter:
                 self._in_queue.put(dp)
-                yield self._out_queue.get()
+                yield self._recv()

-            self._iter = self._iter_ds.get_data()
+            self._iter = self.ds.get_data()
             if self._strict:
                 # first call get() to clear the queues, then fill
                 for k in range(self.buffer_size):
-                    dp = self._out_queue.get()
+                    dp = self._recv()
                     if k == self.buffer_size - 1:
                         self._fill_buffer()
                     yield dp
             else:
                 for _ in range(self.buffer_size):
                     self._in_queue.put(next(self._iter))
-                    yield self._out_queue.get()
+                    yield self._recv()

     def __del__(self):
         if self._evt is not None:
...
@@ -420,7 +433,11 @@ class MultiThreadMapData(ProxyDataFlow):
 ThreadedMapData = MultiThreadMapData


-class MultiProcessMapDataZMQ(ProxyDataFlow):
+class MultiProcessMapDataZMQ(_MultiProcessZMQDataFlow):
     """
     Same as :class:`MapData`, but start processes to run the mapping function,
     and communicate with ZeroMQ pipe.
     """
+
+    class _Worker(mp.Process):
+        def __init__(self, identity, map_func, pipename, hwm):
+            super(MultiProcessMapDataZMQ._Worker, self).__init__()
...
@@ -442,57 +459,50 @@ class MultiProcessMapDataZMQ(ProxyDataFlow):
                 socket.send(dumps(dp), copy=False)

     def __init__(self, ds, nr_proc, map_func, buffer_size=200):
         """
         Args:
             ds (DataFlow): the dataflow to map
             nr_proc(int): number of threads to use
             map_func (callable): datapoint -> datapoint | None
             buffer_size (int): number of datapoints in the buffer
         """
         super(MultiProcessMapDataZMQ, self).__init__(ds)
         self.nr_proc = nr_proc
         self.map_func = map_func
-        self._buffer_size = buffer_size
-        self._procs = []
-        self._reset_done = False
-
-        try:
-            self._size = ds.size()
-        except NotImplementedError:
-            self._size = -1
-
-    def reset_state(self):
-        if self._reset_done:
-            return
-        self._reset_done = True
-        self._guard = DataFlowReentrantGuard()
+        self.buffer_size = buffer_size

+    def _reset_once(self):
         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.ROUTER)
-        self.socket.set_hwm(self._buffer_size * 2)
+        self.socket.set_hwm(self.buffer_size * 2)
         pipename = _get_pipe_name('dataflow-map')
         _bind_guard(self.socket, pipename)
         self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]

-        worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
+        worker_hwm = int(self.buffer_size * 2 // self.nr_proc)
         self._procs = [MultiProcessMapDataZMQ._Worker(
             self._proc_ids[k], self.map_func, pipename, worker_hwm)
             for k in range(self.nr_proc)]

         self.ds.reset_state()
-        self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
+        self._iter = self.ds.get_data()
         self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))

+        self._guard = DataFlowReentrantGuard()
         self._start_processes()
         self._fill_buffer()

-        import atexit
-        atexit.register(lambda x: x.__del__(), self)
-
     def _fill_buffer(self):
-        # Filling the buffer.
-        for _ in range(self._buffer_size):
-            self._send()
-
-    def _start_processes(self):
-        start_proc_mask_signal(self._procs)
+        try:
+            for _ in range(self.buffer_size):
+                self._send(next(self._iter))
+        except StopIteration:
+            logger.error("[MultiProcessMapData] buffer_size cannot be larger than the size of the DataFlow!")
+            raise

-    def _send(self):
-        dp = next(self._iter_ds)
+    def _send(self, dp):
         # round-robin assignment
         worker = next(self._iter_worker)
         msg = [worker, dumps(dp)]
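_iter_worker cycles over the worker identities forever, which is what makes the "round-robin assignment" above work. _repeat_iter is defined elsewhere in this file; the re-implementation below is an assumption for illustration, showing the round-robin effect:

def _repeat_iter(get_itr):
    # Restart the iterator factory whenever it is exhausted.
    while True:
        for x in get_itr():
            yield x

proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(3)]
workers = _repeat_iter(lambda: iter(proc_ids))
assert [next(workers) for _ in range(5)] == [b'0', b'1', b'2', b'0', b'1']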
...
@@ -505,40 +515,32 @@ class MultiProcessMapDataZMQ(ProxyDataFlow):
     def get_data(self):
         with self._guard, _zmq_catch_error('MultiProcessMapData'):
-            for k in itertools.count():
-                if self._size > 0 and k >= self._size:
-                    break
-                self._send()
+            for dp in self._iter:
+                self._send(dp)
                 yield self._recv()

-    def __del__(self):
-        if not self._reset_done:
-            return
-        if not self.context.closed:
-            self.context.destroy(0)
-        for x in self._procs:
-            x.terminate()
-        try:
-            print("MultiProcessMapData successfully cleaned-up.")
-        except:
-            pass
+            self._iter = self.ds.get_data()   # refresh
+            for _ in range(self.buffer_size):
+                self._send(next(self._iter))
+                yield self._recv()


 MultiProcessMapData = MultiProcessMapDataZMQ  # alias


 if __name__ == '__main__':
-    from .base import DataFlow
-
-    class Naive(DataFlow):
+    class Zero(DataFlow):
+        def __init__(self, size):
+            self._size = size
+
         def get_data(self):
-            for k in range(1000):
+            for k in range(self._size):
                 yield [0]

         def size(self):
-            return 100
+            return self._size

-    ds = Naive()
+    ds = Zero(300)
     ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
     ds.reset_state()
     for k in ds.get_data():
...