Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6a4f4ee2
Commit
6a4f4ee2
authored
Oct 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
initial version of MultiProcessMapData (#414)
parent
9eaf6e92
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
202 additions
and
51 deletions
+202
-51
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+202
-51
No files found.
tensorpack/dataflow/prefetch.py
View file @
6a4f4ee2
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
threading
import
threading
from
contextlib
import
contextmanager
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
itertools
import
itertools
from
six.moves
import
range
,
zip
,
queue
from
six.moves
import
range
,
zip
,
queue
...
@@ -21,7 +22,47 @@ from ..utils import logger
...
@@ -21,7 +22,47 @@ from ..utils import logger
from
..utils.gpu
import
change_gpu
from
..utils.gpu
import
change_gpu
__all__
=
[
'PrefetchData'
,
'PrefetchDataZMQ'
,
'PrefetchOnGPUs'
,
__all__
=
[
'PrefetchData'
,
'PrefetchDataZMQ'
,
'PrefetchOnGPUs'
,
'ThreadedMapData'
]
'ThreadedMapData'
,
'MultiThreadMapData'
,
'MultiProcessMapData'
]
def
_repeat_iter
(
get_itr
):
while
True
:
for
x
in
get_itr
():
yield
x
def _bind_guard(sock, name):
    """Bind ``sock`` to the address ``name``, logging a hint on failure.

    Args:
        sock: a ZMQ socket.
        name (str): address to bind to, e.g. an ``ipc://`` pipe name.

    Raises:
        zmq.ZMQError: re-raised after logging; binding an ipc pipe can fail
            when TENSORPACK_PIPEDIR points to a non-local file system.
    """
    try:
        sock.bind(name)
    except zmq.ZMQError:
        logger.error(
            "ZMQError in socket.bind(). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information.")
        raise
def
_get_pipe_name
(
name
):
pipedir
=
os
.
environ
.
get
(
'TENSORPACK_PIPEDIR'
,
'.'
)
assert
os
.
path
.
isdir
(
pipedir
),
pipedir
pipename
=
"ipc://{}/{}-pipe-"
.
format
(
pipedir
.
rstrip
(
'/'
),
name
)
+
str
(
uuid
.
uuid1
())[:
6
]
return
pipename
@
contextmanager
def
_zmq_catch_error
(
name
):
try
:
yield
except
zmq
.
ContextTerminated
:
logger
.
info
(
"[{}] Context terminated."
.
format
(
name
))
raise
DataFlowTerminated
()
except
zmq
.
ZMQError
as
e
:
if
e
.
errno
==
errno
.
ENOTSOCK
:
# socket closed
logger
.
info
(
"[{}] Socket closed."
.
format
(
name
))
raise
DataFlowTerminated
()
else
:
raise
except
:
raise
class
PrefetchProcess
(
mp
.
Process
):
class
PrefetchProcess
(
mp
.
Process
):
...
@@ -105,14 +146,14 @@ class PrefetchProcessZMQ(mp.Process):
...
@@ -105,14 +146,14 @@ class PrefetchProcessZMQ(mp.Process):
def
run
(
self
):
def
run
(
self
):
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
self
.
context
=
zmq
.
Context
()
context
=
zmq
.
Context
()
s
elf
.
socket
=
self
.
context
.
socket
(
zmq
.
PUSH
)
s
ocket
=
context
.
socket
(
zmq
.
PUSH
)
s
elf
.
s
ocket
.
set_hwm
(
self
.
hwm
)
socket
.
set_hwm
(
self
.
hwm
)
s
elf
.
s
ocket
.
connect
(
self
.
conn_name
)
socket
.
connect
(
self
.
conn_name
)
try
:
try
:
while
True
:
while
True
:
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
s
elf
.
s
ocket
.
send
(
dumps
(
dp
),
copy
=
False
)
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
# sigint could still propagate here, e.g. when nested
# sigint could still propagate here, e.g. when nested
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
pass
pass
...
@@ -170,53 +211,34 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -170,53 +211,34 @@ class PrefetchDataZMQ(ProxyDataFlow):
self
.
_hwm
=
hwm
self
.
_hwm
=
hwm
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
_guard
=
DataFlowReentrantGuard
()
self
.
_setup_done
=
False
self
.
_reset_done
=
False
def
_recv
(
self
):
return
loads
(
self
.
socket
.
recv
(
copy
=
False
)
.
bytes
)
def
get_data
(
self
):
def
get_data
(
self
):
with
self
.
_guard
:
with
self
.
_guard
,
_zmq_catch_error
(
'PrefetchDataZMQ'
):
try
:
for
k
in
itertools
.
count
():
for
k
in
itertools
.
count
():
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
if
self
.
_size
>
0
and
k
>=
self
.
_size
:
break
break
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
)
.
bytes
)
yield
self
.
_recv
()
yield
dp
except
zmq
.
ContextTerminated
:
logger
.
info
(
"[Prefetch Master] Context terminated."
)
raise
DataFlowTerminated
()
except
zmq
.
ZMQError
as
e
:
if
e
.
errno
==
errno
.
ENOTSOCK
:
# socket closed
logger
.
info
(
"[Prefetch Master] Socket closed."
)
raise
DataFlowTerminated
()
else
:
raise
except
:
raise
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
"""
All forked dataflows are reset **once and only once** in spawned processes.
All forked dataflows are reset **once and only once** in spawned processes.
Nothing more can be done when calling this method.
Nothing more can be done when calling this method.
"""
"""
if
self
.
_
setup
_done
:
if
self
.
_
reset
_done
:
return
return
self
.
_
setup
_done
=
True
self
.
_
reset
_done
=
True
self
.
context
=
zmq
.
Context
()
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
pipedir
=
os
.
environ
.
get
(
'TENSORPACK_PIPEDIR'
,
'.'
)
assert
os
.
path
.
isdir
(
pipedir
),
pipedir
self
.
pipename
=
"ipc://{}/dataflow-pipe-"
.
format
(
pipedir
.
rstrip
(
'/'
))
+
str
(
uuid
.
uuid1
())[:
6
]
self
.
socket
.
set_hwm
(
self
.
_hwm
)
self
.
socket
.
set_hwm
(
self
.
_hwm
)
try
:
pipename
=
_get_pipe_name
(
'dataflow'
)
self
.
socket
.
bind
(
self
.
pipename
)
_bind_guard
(
self
.
socket
,
pipename
)
except
zmq
.
ZMQError
:
logger
.
error
(
"ZMQError in socket.bind(). Perhaps you're
\
using pipes on a non-local file system. See documentation of PrefetchDataZMQ for more information."
)
raise
self
.
procs
=
[
PrefetchProcessZMQ
(
self
.
ds
,
self
.
pipename
,
self
.
_hwm
)
self
.
procs
=
[
PrefetchProcessZMQ
(
self
.
ds
,
pipename
,
self
.
_hwm
)
for
_
in
range
(
self
.
nr_proc
)]
for
_
in
range
(
self
.
nr_proc
)]
self
.
_start_processes
()
self
.
_start_processes
()
# __del__ not guranteed to get called at exit
# __del__ not guranteed to get called at exit
...
@@ -227,15 +249,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -227,15 +249,14 @@ class PrefetchDataZMQ(ProxyDataFlow):
start_proc_mask_signal
(
self
.
procs
)
start_proc_mask_signal
(
self
.
procs
)
def
__del__
(
self
):
def
__del__
(
self
):
if
not
self
.
_
setup
_done
:
if
not
self
.
_
reset
_done
:
return
return
if
not
self
.
context
.
closed
:
if
not
self
.
context
.
closed
:
self
.
context
.
destroy
(
0
)
self
.
context
.
destroy
(
0
)
for
x
in
self
.
procs
:
for
x
in
self
.
procs
:
x
.
terminate
()
x
.
terminate
()
try
:
try
:
# TODO test if logger here would overwrite log file
print
(
"PrefetchDataZMQ successfully cleaned-up."
)
print
(
"Prefetch process exited."
)
except
:
except
:
pass
pass
...
@@ -263,7 +284,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
...
@@ -263,7 +284,7 @@ class PrefetchOnGPUs(PrefetchDataZMQ):
proc
.
start
()
proc
.
start
()
class
Threade
dMapData
(
ProxyDataFlow
):
class
MultiThrea
dMapData
(
ProxyDataFlow
):
"""
"""
Same as :class:`MapData`, but start threads to run the mapping function.
Same as :class:`MapData`, but start threads to run the mapping function.
This is useful when the mapping function is the bottleneck, but you don't
This is useful when the mapping function is the bottleneck, but you don't
...
@@ -274,7 +295,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -274,7 +295,7 @@ class ThreadedMapData(ProxyDataFlow):
should avoid starting many threads in your main process to reduce GIL contention.
should avoid starting many threads in your main process to reduce GIL contention.
The threads will only start in the process which calls :meth:`reset_state()`.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(
Threade
dMapData(...), 1)``
Therefore you can use ``PrefetchDataZMQ(
MultiThrea
dMapData(...), 1)``
to reduce GIL contention.
to reduce GIL contention.
2. Threads run in parallel and can take different time to run the
2. Threads run in parallel and can take different time to run the
...
@@ -282,13 +303,13 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -282,13 +303,13 @@ class ThreadedMapData(ProxyDataFlow):
preserved, and datapoints from one pass of `df.get_data()` might get
preserved, and datapoints from one pass of `df.get_data()` might get
mixed with datapoints from the next pass.
mixed with datapoints from the next pass.
You can use **strict mode**, where `
Threade
dMapData.get_data()`
You can use **strict mode**, where `
MultiThrea
dMapData.get_data()`
is guranteed to produce the exact set which `df.get_data()`
is guranteed to produce the exact set which `df.get_data()`
produces. Although the order of data still isn't preserved.
produces. Although the order of data still isn't preserved.
"""
"""
class
_WorkerThread
(
StoppableThread
):
class
_WorkerThread
(
StoppableThread
):
def
__init__
(
self
,
inq
,
outq
,
evt
,
map_func
,
strict
):
def
__init__
(
self
,
inq
,
outq
,
evt
,
map_func
,
strict
):
super
(
Threade
dMapData
.
_WorkerThread
,
self
)
.
__init__
(
evt
)
super
(
MultiThrea
dMapData
.
_WorkerThread
,
self
)
.
__init__
(
evt
)
self
.
inq
=
inq
self
.
inq
=
inq
self
.
outq
=
outq
self
.
outq
=
outq
self
.
func
=
map_func
self
.
func
=
map_func
...
@@ -307,7 +328,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -307,7 +328,7 @@ class ThreadedMapData(ProxyDataFlow):
self
.
outq
.
put
(
dp
)
self
.
outq
.
put
(
dp
)
else
:
else
:
assert
not
self
.
_strict
,
\
assert
not
self
.
_strict
,
\
"[
Threade
dMapData] Map function cannot return None when strict mode is used."
"[
MultiThrea
dMapData] Map function cannot return None when strict mode is used."
except
:
except
:
if
self
.
stopped
():
if
self
.
stopped
():
pass
# skip duplicated error messages
pass
# skip duplicated error messages
...
@@ -325,7 +346,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -325,7 +346,7 @@ class ThreadedMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
strict (bool): use "strict mode", see notes above.
"""
"""
super
(
Threade
dMapData
,
self
)
.
__init__
(
ds
)
super
(
MultiThrea
dMapData
,
self
)
.
__init__
(
ds
)
self
.
_iter_ds
=
ds
self
.
_iter_ds
=
ds
self
.
_strict
=
strict
self
.
_strict
=
strict
...
@@ -336,7 +357,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -336,7 +357,7 @@ class ThreadedMapData(ProxyDataFlow):
self
.
_evt
=
None
self
.
_evt
=
None
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
Threade
dMapData
,
self
)
.
reset_state
()
super
(
MultiThrea
dMapData
,
self
)
.
reset_state
()
if
self
.
_threads
:
if
self
.
_threads
:
self
.
_threads
[
0
]
.
stop
()
self
.
_threads
[
0
]
.
stop
()
for
t
in
self
.
_threads
:
for
t
in
self
.
_threads
:
...
@@ -345,7 +366,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -345,7 +366,7 @@ class ThreadedMapData(ProxyDataFlow):
self
.
_in_queue
=
queue
.
Queue
()
self
.
_in_queue
=
queue
.
Queue
()
self
.
_out_queue
=
queue
.
Queue
()
self
.
_out_queue
=
queue
.
Queue
()
self
.
_evt
=
threading
.
Event
()
self
.
_evt
=
threading
.
Event
()
self
.
_threads
=
[
Threade
dMapData
.
_WorkerThread
(
self
.
_threads
=
[
MultiThrea
dMapData
.
_WorkerThread
(
self
.
_in_queue
,
self
.
_out_queue
,
self
.
_evt
,
self
.
map_func
,
self
.
_strict
)
self
.
_in_queue
,
self
.
_out_queue
,
self
.
_evt
,
self
.
map_func
,
self
.
_strict
)
for
_
in
range
(
self
.
nr_thread
)]
for
_
in
range
(
self
.
nr_thread
)]
for
t
in
self
.
_threads
:
for
t
in
self
.
_threads
:
...
@@ -366,7 +387,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -366,7 +387,7 @@ class ThreadedMapData(ProxyDataFlow):
for
_
in
range
(
n
):
for
_
in
range
(
n
):
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
self
.
_in_queue
.
put
(
next
(
self
.
_iter
))
except
StopIteration
:
except
StopIteration
:
logger
.
error
(
"[
Threade
dMapData] buffer_size cannot be larger than the size of the DataFlow!"
)
logger
.
error
(
"[
MultiThrea
dMapData] buffer_size cannot be larger than the size of the DataFlow!"
)
raise
raise
def
get_data
(
self
):
def
get_data
(
self
):
...
@@ -393,3 +414,133 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -393,3 +414,133 @@ class ThreadedMapData(ProxyDataFlow):
self
.
_evt
.
set
()
self
.
_evt
.
set
()
for
p
in
self
.
_threads
:
for
p
in
self
.
_threads
:
p
.
join
()
p
.
join
()
# TODO deprecated
# Backward-compatibility alias: the class was renamed to MultiThreadMapData.
ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(ProxyDataFlow):
    """
    Run ``map_func`` over datapoints of ``ds`` in ``nr_proc`` worker processes,
    using ZMQ for inter-process communication.

    A ROUTER socket in the parent assigns serialized datapoints to DEALER
    workers round-robin; each worker applies ``map_func`` and sends the result
    back.  Up to ``buffer_size`` datapoints are kept in flight, so the order
    of the output may not match the order of the input.
    """

    class _Worker(mp.Process):
        # One worker process: receive a datapoint, map it, send the result back.
        def __init__(self, identity, map_func, pipename, hwm):
            super(MultiProcessMapDataZMQ._Worker, self).__init__()
            self.identity = identity
            self.map_func = map_func
            self.pipename = pipename
            self.hwm = hwm

        def run(self):
            ctx = zmq.Context()
            worker_sock = ctx.socket(zmq.DEALER)
            worker_sock.setsockopt(zmq.IDENTITY, self.identity)
            worker_sock.set_hwm(self.hwm)
            worker_sock.connect(self.pipename)

            while True:
                datapoint = loads(worker_sock.recv(copy=False).bytes)
                mapped = self.map_func(datapoint)
                worker_sock.send(dumps(mapped), copy=False)

    def __init__(self, ds, nr_proc, map_func, buffer_size=200):
        """
        Args:
            ds (DataFlow): the dataflow to map over.
            nr_proc (int): number of worker processes.
            map_func (callable): datapoint -> datapoint.
            buffer_size (int): number of datapoints kept in flight.
        """
        super(MultiProcessMapDataZMQ, self).__init__(ds)
        self.nr_proc = nr_proc
        self.map_func = map_func
        self._buffer_size = buffer_size
        self._procs = []
        self._reset_done = False
        try:
            self._size = ds.size()
        except NotImplementedError:
            self._size = -1     # unknown size: get_data() iterates forever

    def reset_state(self):
        """Start the worker processes and pre-fill the pipeline (once only)."""
        if self._reset_done:
            return
        self._reset_done = True

        # ROUTER end lives in this (consumer) process
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.socket.set_hwm(self._buffer_size * 2)
        pipename = _get_pipe_name('dataflow-map')
        _bind_guard(self.socket, pipename)

        # each worker gets a distinct ZMQ identity for ROUTER addressing
        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
        worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
        self._procs = [MultiProcessMapDataZMQ._Worker(
            self._proc_ids[k], self.map_func, pipename, worker_hwm)
            for k in range(self.nr_proc)]

        self.ds.reset_state()
        self._iter_ds = _repeat_iter(lambda: self.ds.get_data())
        self._iter_worker = _repeat_iter(lambda: iter(self._proc_ids))

        self._guard = DataFlowReentrantGuard()
        self._start_processes()
        self._fill_buffer()

        # __del__ is not guaranteed to run at interpreter exit; register an
        # explicit cleanup hook as well
        import atexit
        atexit.register(lambda x: x.__del__(), self)

    def _fill_buffer(self):
        # Pre-send a full buffer of datapoints so workers start busy.
        for _ in range(self._buffer_size):
            self._send()

    def _start_processes(self):
        start_proc_mask_signal(self._procs)

    def _send(self):
        datapoint = next(self._iter_ds)
        # round-robin assignment
        target = next(self._iter_worker)
        parts = [target, dumps(datapoint)]
        self.socket.send_multipart(parts, copy=False)

    def _recv(self):
        # multipart frame 0 is the worker identity; frame 1 is the payload
        parts = self.socket.recv_multipart(copy=False)
        return loads(parts[1].bytes)

    def get_data(self):
        """Yield mapped datapoints, keeping the pipeline full as results arrive."""
        with self._guard, _zmq_catch_error('MultiProcessMapData'):
            served = 0
            while self._size <= 0 or served < self._size:
                yield self._recv()
                self._send()    # replace the consumed datapoint in flight
                served += 1

    def __del__(self):
        if not self._reset_done:    # nothing was ever started
            return
        if not self.context.closed:
            self.context.destroy(0)
        for proc in self._procs:
            proc.terminate()
        try:
            print("MultiProcessMapData successfully cleaned-up.")
        except:
            pass
# ZMQ is currently the only implementation, so expose it under the generic name.
MultiProcessMapData = MultiProcessMapDataZMQ  # alias
if __name__ == '__main__':
    # Manual smoke test: map a trivial dataflow through 3 worker processes.
    from .base import DataFlow

    class Naive(DataFlow):
        # Yields [0] up to 1000 times, but reports a size of 100.
        def get_data(self):
            for _ in range(1000):
                yield [0]

        def size(self):
            return 100

    ds = Naive()
    ds = MultiProcessMapData(ds, 3, lambda x: [x[0] + 1])
    ds.reset_state()
    for dp in ds.get_data():
        print("Bang!", dp)
    print("END!")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment