Shashank Suhas / seminar-breakout · Commits

Commit 6a4f4ee2
Authored Oct 13, 2017 by Yuxin Wu
Parent 9eaf6e92

    initial version of MultiProcessMapData (#414)

Showing 1 changed file with 202 additions and 51 deletions:
    tensorpack/dataflow/prefetch.py  (+202, -51)
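In brief: this commit renames ThreadedMapData to MultiThreadMapData (keeping the old name as a deprecated alias) and adds MultiProcessMapData, which runs the mapping function in worker processes connected over a ZMQ ROUTER/DEALER pipe. A hedged usage sketch, distilled from the __main__ demo at the bottom of this diff (my_dataflow is a placeholder for any DataFlow):

    # Sketch only; the import path is assumed from this file's location in the repo.
    from tensorpack.dataflow.prefetch import MultiProcessMapData

    ds = my_dataflow                                         # placeholder DataFlow
    ds = MultiProcessMapData(ds, 3, lambda dp: [dp[0] + 1])  # 3 worker processes
    ds.reset_state()                 # binds the pipe and starts the workers
    for dp in ds.get_data():         # mapped datapoints; order is not preserved
        pass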
@@ -4,6 +4,7 @@
 from __future__ import print_function
 import threading
+from contextlib import contextmanager
 import multiprocessing as mp
 import itertools
 from six.moves import range, zip, queue
@@ -21,7 +22,47 @@ from ..utils import logger
 from ..utils.gpu import change_gpu

 __all__ = ['PrefetchData', 'PrefetchDataZMQ', 'PrefetchOnGPUs',
-           'ThreadedMapData']
+           'ThreadedMapData', 'MultiThreadMapData', 'MultiProcessMapData']
+
+
+def _repeat_iter(get_itr):
+    while True:
+        for x in get_itr():
+            yield x
+
+
+def _bind_guard(sock, name):
+    try:
+        sock.bind(name)
+    except zmq.ZMQError:
+        logger.error(
+            "ZMQError in socket.bind(). Perhaps you're \
+            using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
+        raise
+
+
+def _get_pipe_name(name):
+    pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
+    assert os.path.isdir(pipedir), pipedir
+    pipename = "ipc://{}/{}-pipe-".format(pipedir.rstrip('/'), name) + str(uuid.uuid1())[:6]
+    return pipename
+
+
+@contextmanager
+def _zmq_catch_error(name):
+    try:
+        yield
+    except zmq.ContextTerminated:
+        logger.info("[{}] Context terminated.".format(name))
+        raise DataFlowTerminated()
+    except zmq.ZMQError as e:
+        if e.errno == errno.ENOTSOCK:       # socket closed
+            logger.info("[{}] Socket closed.".format(name))
+            raise DataFlowTerminated()
+        else:
+            raise
+    except:
+        raise
+
+
 class PrefetchProcess(mp.Process):
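Note on the new _get_pipe_name helper: the IPC pipe is a file created under TENSORPACK_PIPEDIR (defaulting to the current directory), so the variable should point at a local filesystem, as the _bind_guard error message warns. A small illustration of what it produces (the 6-character uuid suffix below is made up):

    import os
    os.environ['TENSORPACK_PIPEDIR'] = '/tmp'   # keep the pipe off NFS and the like
    # _get_pipe_name('dataflow') would then return something like:
    #   "ipc:///tmp/dataflow-pipe-8c4a2e"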
@@ -105,14 +146,14 @@ class PrefetchProcessZMQ(mp.Process):

     def run(self):
         self.ds.reset_state()
-        self.context = zmq.Context()
-        self.socket = self.context.socket(zmq.PUSH)
-        self.socket.set_hwm(self.hwm)
-        self.socket.connect(self.conn_name)
+        context = zmq.Context()
+        socket = context.socket(zmq.PUSH)
+        socket.set_hwm(self.hwm)
+        socket.connect(self.conn_name)
         try:
             while True:
                 for dp in self.ds.get_data():
-                    self.socket.send(dumps(dp), copy=False)
+                    socket.send(dumps(dp), copy=False)
         # sigint could still propagate here, e.g. when nested
         except KeyboardInterrupt:
             pass
@@ -170,53 +211,34 @@ class PrefetchDataZMQ(ProxyDataFlow):
         self._hwm = hwm

         self._guard = DataFlowReentrantGuard()
-        self._setup_done = False
+        self._reset_done = False
+
+    def _recv(self):
+        return loads(self.socket.recv(copy=False).bytes)

     def get_data(self):
-        with self._guard:
-            try:
-                for k in itertools.count():
-                    if self._size > 0 and k >= self._size:
-                        break
-                    dp = loads(self.socket.recv(copy=False).bytes)
-                    yield dp
-            except zmq.ContextTerminated:
-                logger.info("[Prefetch Master] Context terminated.")
-                raise DataFlowTerminated()
-            except zmq.ZMQError as e:
-                if e.errno == errno.ENOTSOCK:       # socket closed
-                    logger.info("[Prefetch Master] Socket closed.")
-                    raise DataFlowTerminated()
-                else:
-                    raise
-            except:
-                raise
+        with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
+            for k in itertools.count():
+                if self._size > 0 and k >= self._size:
+                    break
+                yield self._recv()

     def reset_state(self):
         """
         All forked dataflows are reset **once and only once** in spawned processes.
         Nothing more can be done when calling this method.
         """
-        if self._setup_done:
+        if self._reset_done:
             return
-        self._setup_done = True
+        self._reset_done = True

         self.context = zmq.Context()
         self.socket = self.context.socket(zmq.PULL)
-        pipedir = os.environ.get('TENSORPACK_PIPEDIR', '.')
-        assert os.path.isdir(pipedir), pipedir
-        self.pipename = "ipc://{}/dataflow-pipe-".format(pipedir.rstrip('/')) + str(uuid.uuid1())[:6]
         self.socket.set_hwm(self._hwm)
-        try:
-            self.socket.bind(self.pipename)
-        except zmq.ZMQError:
-            logger.error(
-                "ZMQError in socket.bind(). Perhaps you're \
-                using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
-            raise
+        pipename = _get_pipe_name('dataflow')
+        _bind_guard(self.socket, pipename)

-        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename, self._hwm)
+        self.procs = [PrefetchProcessZMQ(self.ds, pipename, self._hwm)
                       for _ in range(self.nr_proc)]
         self._start_processes()
         # __del__ not guranteed to get called at exit
@@ -227,15 +249,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
         start_proc_mask_signal(self.procs)

     def __del__(self):
-        if not self._setup_done:
+        if not self._reset_done:
             return
         if not self.context.closed:
             self.context.destroy(0)
         for x in self.procs:
             x.terminate()
         try:
             # TODO test if logger here would overwrite log file
-            print("Prefetch process exited.")
+            print("PrefetchDataZMQ successfully cleaned-up.")
         except:
             pass
@@ -263,7 +284,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
             proc.start()


-class ThreadedMapData(ProxyDataFlow):
+class MultiThreadMapData(ProxyDataFlow):
     """
     Same as :class:`MapData`, but start threads to run the mapping function.
     This is useful when the mapping function is the bottleneck, but you don't
@@ -274,7 +295,7 @@ class ThreadedMapData(ProxyDataFlow):
        should avoid starting many threads in your main process to reduce GIL contention.

        The threads will only start in the process which calls :meth:`reset_state()`.
-       Therefore you can use ``PrefetchDataZMQ(ThreadedMapData(...), 1)``
+       Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)``
        to reduce GIL contention.

     2. Threads run in parallel and can take different time to run the
@@ -282,13 +303,13 @@ class ThreadedMapData(ProxyDataFlow):
        preserved, and datapoints from one pass of `df.get_data()` might get
        mixed with datapoints from the next pass.

-       You can use **strict mode**, where `ThreadedMapData.get_data()`
+       You can use **strict mode**, where `MultiThreadMapData.get_data()`
        is guranteed to produce the exact set which `df.get_data()`
        produces. Although the order of data still isn't preserved.
     """
     class _WorkerThread(StoppableThread):
         def __init__(self, inq, outq, evt, map_func, strict):
-            super(ThreadedMapData._WorkerThread, self).__init__(evt)
+            super(MultiThreadMapData._WorkerThread, self).__init__(evt)
             self.inq = inq
             self.outq = outq
             self.func = map_func
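A hedged usage sketch of the strict mode described above, combined with PrefetchDataZMQ as the docstring suggests (df and my_func are placeholders; the argument order follows the constructor later in this diff):

    # 5 threads run my_func; strict mode guarantees each epoch yields exactly
    # the set produced by df.get_data(), though still in no particular order.
    ds = MultiThreadMapData(df, 5, my_func, buffer_size=100, strict=True)
    # Run the whole thing in one separate process so the threads don't
    # contend with the main process for the GIL:
    ds = PrefetchDataZMQ(ds, 1)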
@@ -307,7 +328,7 @@ class ThreadedMapData(ProxyDataFlow):
                     self.outq.put(dp)
                 else:
                     assert not self._strict, \
-                        "[ThreadedMapData] Map function cannot return None when strict mode is used."
+                        "[MultiThreadMapData] Map function cannot return None when strict mode is used."
             except:
                 if self.stopped():
                     pass        # skip duplicated error messages
@@ -325,7 +346,7 @@ class ThreadedMapData(ProxyDataFlow):
             buffer_size (int): number of datapoints in the buffer
             strict (bool): use "strict mode", see notes above.
         """
-        super(ThreadedMapData, self).__init__(ds)
+        super(MultiThreadMapData, self).__init__(ds)

         self._iter_ds = ds
         self._strict = strict
@@ -336,7 +357,7 @@ class ThreadedMapData(ProxyDataFlow):
         self._evt = None

     def reset_state(self):
-        super(ThreadedMapData, self).reset_state()
+        super(MultiThreadMapData, self).reset_state()
         if self._threads:
             self._threads[0].stop()
             for t in self._threads:
@@ -345,7 +366,7 @@ class ThreadedMapData(ProxyDataFlow):
         self._in_queue = queue.Queue()
         self._out_queue = queue.Queue()
         self._evt = threading.Event()
-        self._threads = [ThreadedMapData._WorkerThread(
+        self._threads = [MultiThreadMapData._WorkerThread(
             self._in_queue, self._out_queue, self._evt, self.map_func, self._strict)
             for _ in range(self.nr_thread)]
         for t in self._threads:
@@ -366,7 +387,7 @@ class ThreadedMapData(ProxyDataFlow):
             for _ in range(n):
                 self._in_queue.put(next(self._iter))
         except StopIteration:
-            logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
+            logger.error("[MultiThreadMapData] buffer_size cannot be larger than the size of the DataFlow!")
             raise

     def get_data(self):
@@ -393,3 +414,133 @@ class ThreadedMapData(ProxyDataFlow):
         self._evt.set()
         for p in self._threads:
             p.join()
+
+
+# TODO deprecated
+ThreadedMapData = MultiThreadMapData
+
+
+class MultiProcessMapDataZMQ(ProxyDataFlow):
+    class _Worker(mp.Process):
+        def __init__(self, identity, map_func, pipename, hwm):
+            super(MultiProcessMapDataZMQ._Worker, self).__init__()
+            self.identity = identity
+            self.map_func = map_func
+            self.pipename = pipename
+            self.hwm = hwm
+
+        def run(self):
+            ctx = zmq.Context()
+            socket = ctx.socket(zmq.DEALER)
+            socket.setsockopt(zmq.IDENTITY, self.identity)
+            socket.set_hwm(self.hwm)
+            socket.connect(self.pipename)
+
+            while True:
+                dp = loads(socket.recv(copy=False).bytes)
+                dp = self.map_func(dp)
+                socket.send(dumps(dp), copy=False)
+
+    def __init__(self, ds, nr_proc, map_func, buffer_size=200):
+        super(MultiProcessMapDataZMQ, self).__init__(ds)
+        self.nr_proc = nr_proc
+        self.map_func = map_func
+        self._buffer_size = buffer_size
+        self._procs = []
+        self._reset_done = False
+
+        try:
+            self._size = ds.size()
+        except NotImplementedError:
+            self._size = -1
+
+    def reset_state(self):
+        if self._reset_done:
+            return
+        self._reset_done = True
+
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.ROUTER)
+        self.socket.set_hwm(self._buffer_size * 2)
+        pipename = _get_pipe_name('dataflow-map')
+        _bind_guard(self.socket, pipename)
+
+        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
+        worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
+        self._procs = [MultiProcessMapDataZMQ._Worker(
+            self._proc_ids[k], self.map_func, pipename, worker_hwm)
+            for k in range(self.nr_proc)]
+
+        self.ds.reset_state()
+        self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
+        self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))
+
+        self._guard = DataFlowReentrantGuard()
+
+        self._start_processes()
+        self._fill_buffer()
+
+        import atexit
+        atexit.register(lambda x: x.__del__(), self)
+
+    def _fill_buffer(self):
+        # Filling the buffer.
+        for _ in range(self._buffer_size):
+            self._send()
+
+    def _start_processes(self):
+        start_proc_mask_signal(self._procs)
+
+    def _send(self):
+        dp = next(self._iter_ds)
+        # round-robin assignment
+        worker = next(self._iter_worker)
+        msg = [worker, dumps(dp)]
+        self.socket.send_multipart(msg, copy=False)
+
+    def _recv(self):
+        msg = self.socket.recv_multipart(copy=False)
+        dp = loads(msg[1].bytes)
+        return dp
+
+    def get_data(self):
+        with self._guard, _zmq_catch_error('MultiProcessMapData'):
+            for k in itertools.count():
+                if self._size > 0 and k >= self._size:
+                    break
+                yield self._recv()
+                self._send()
+
+    def __del__(self):
+        if not self._reset_done:
+            return
+        if not self.context.closed:
+            self.context.destroy(0)
+        for x in self._procs:
+            x.terminate()
+        try:
+            print("MultiProcessMapData successfully cleaned-up.")
+        except:
+            pass
+
+
+MultiProcessMapData = MultiProcessMapDataZMQ  # alias
+
+
+if __name__ == '__main__':
+    from .base import DataFlow
+
+    class Naive(DataFlow):
+        def get_data(self):
+            for k in range(1000):
+                yield [0]
+
+        def size(self):
+            return 100
+
+    ds = Naive()
+    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
+    ds.reset_state()
+    for k in ds.get_data():
+        print("Bang!", k)
+    print("END!")
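For readers unfamiliar with the socket pattern above: the master's ROUTER socket picks a specific worker by prefixing that worker's IDENTITY frame to the message, while each worker's DEALER socket sees only the payload; replies come back tagged with the sender's identity. A minimal self-contained sketch of the same pattern (not part of this commit; threads and inproc:// stand in for processes and ipc://):

    import threading
    import time
    import zmq

    ctx = zmq.Context()

    def worker(identity):
        sock = ctx.socket(zmq.DEALER)
        sock.setsockopt(zmq.IDENTITY, identity)  # how the ROUTER addresses us
        sock.connect("inproc://demo")
        task = sock.recv()           # DEALER receives the payload frame only
        sock.send(task.upper())      # reply is routed back to the ROUTER
        sock.close()

    master = ctx.socket(zmq.ROUTER)
    master.bind("inproc://demo")     # bind before workers connect (inproc requirement)

    ids = [b'0', b'1']
    threads = [threading.Thread(target=worker, args=(i,)) for i in ids]
    for t in threads:
        t.start()
    time.sleep(0.2)  # crude: ROUTER silently drops sends to not-yet-connected identities

    for i in ids:
        master.send_multipart([i, b'hello'])    # first frame selects the worker
    for _ in ids:
        ident, reply = master.recv_multipart()  # [identity, payload]
        print(ident, reply)

    for t in threads:
        t.join()
    master.close()
    ctx.term()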