Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0fa90c41
Commit
0fa90c41
authored
Nov 27, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ParallelMap: use infinite iterator when strict=False
parent
0a5739fa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
47 deletions
+58
-47
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+19
-11
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+39
-34
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+0
-2
No files found.
tensorpack/dataflow/parallel.py
View file @
0fa90c41
...
...
@@ -95,21 +95,15 @@ class _MultiProcessZMQDataFlow(DataFlow):
def
reset_state
(
self
):
"""
All forked dataflows
ar
e reset **once and only once** in spawned processes.
Nothing more can be done when calling this method
.
All forked dataflows
should only b
e reset **once and only once** in spawned processes.
Subclasses should call this method with super
.
"""
if
self
.
_reset_done
:
return
assert
not
self
.
_reset_done
,
"reset_state() was called twice! This violates the API of DataFlow!"
self
.
_reset_done
=
True
# __del__ not guaranteed to get called at exit
atexit
.
register
(
del_weakref
,
weakref
.
ref
(
self
))
self
.
_reset_once
()
# build processes
def
_reset_once
(
self
):
pass
def
_start_processes
(
self
):
start_proc_mask_signal
(
self
.
_procs
)
...
...
@@ -315,7 +309,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
break
yield
self
.
_recv
()
def
_reset_once
(
self
):
def
reset_state
(
self
):
super
(
PrefetchDataZMQ
,
self
)
.
reset_state
()
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
self
.
socket
.
set_hwm
(
self
.
_hwm
)
...
...
@@ -400,7 +395,7 @@ class MultiThreadPrefetchData(DataFlow):
th
.
start
()
def
__len__
(
self
):
return
self
.
threads
[
0
]
.
__len__
()
return
self
.
threads
[
0
]
.
df
.
__len__
()
def
__iter__
(
self
):
while
True
:
...
...
@@ -463,3 +458,16 @@ plasma = None
# from ..utils.develop import create_dummy_class
# PlasmaPutData = create_dummy_class('PlasmaPutData', 'pyarrow') # noqa
# PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow') # noqa
if
__name__
==
'__main__'
:
import
time
from
.raw
import
DataFromGenerator
from
.common
import
FixedSizeData
x
=
DataFromGenerator
(
itertools
.
count
())
x
=
FixedSizeData
(
x
,
100
)
x
=
PrefetchDataZMQ
(
x
,
2
)
x
.
reset_state
()
for
idx
,
dp
in
enumerate
(
x
):
print
(
dp
)
time
.
sleep
(
0.1
)
tensorpack/dataflow/parallel_map.py
View file @
0fa90c41
...
...
@@ -9,6 +9,7 @@ from six.moves import queue
import
zmq
from
.base
import
DataFlow
,
ProxyDataFlow
,
DataFlowReentrantGuard
from
.common
import
RepeatedData
from
..utils.concurrency
import
StoppableThread
,
enable_death_signal
from
..utils
import
logger
from
..utils.serialize
import
loads
,
dumps
...
...
@@ -23,11 +24,18 @@ __all__ = ['ThreadedMapData', 'MultiThreadMapData',
class
_ParallelMapData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
buffer_size
):
def
__init__
(
self
,
ds
,
buffer_size
,
strict
=
False
):
if
not
strict
:
ds
=
RepeatedData
(
ds
,
-
1
)
super
(
_ParallelMapData
,
self
)
.
__init__
(
ds
)
assert
buffer_size
>
0
,
buffer_size
self
.
_buffer_size
=
buffer_size
self
.
_buffer_occupancy
=
0
# actual #elements in buffer
self
.
_buffer_occupancy
=
0
# actual #elements in buffer, only useful in strict mode
self
.
_strict
=
strict
def
reset_state
(
self
):
super
(
_ParallelMapData
,
self
)
.
reset_state
()
self
.
_iter
=
self
.
ds
.
__iter__
()
def
_recv
(
self
):
pass
...
...
@@ -50,7 +58,8 @@ class _ParallelMapData(ProxyDataFlow):
self
.
_send
(
dp
)
except
StopIteration
:
logger
.
error
(
"[{}] buffer_size cannot be larger than the size of the DataFlow!"
.
format
(
type
(
self
)
.
__name__
))
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!"
.
format
(
type
(
self
)
.
__name__
))
raise
self
.
_buffer_occupancy
+=
cnt
...
...
@@ -61,13 +70,6 @@ class _ParallelMapData(ProxyDataFlow):
if
ret
is
not
None
:
yield
ret
self
.
_iter
=
self
.
ds
.
__iter__
()
# refresh
for
_
in
range
(
self
.
_buffer_size
):
self
.
_send
(
next
(
self
.
_iter
))
ret
=
self
.
_recv
()
if
ret
is
not
None
:
yield
ret
def
get_data_strict
(
self
):
self
.
_fill_buffer
()
for
dp
in
self
.
_iter
:
...
...
@@ -83,6 +85,14 @@ class _ParallelMapData(ProxyDataFlow):
self
.
_fill_buffer
()
yield
dp
def
__iter__
(
self
):
if
self
.
_strict
:
for
dp
in
self
.
get_data_strict
():
yield
dp
else
:
for
dp
in
self
.
get_data_non_strict
():
yield
dp
class
MultiThreadMapData
(
_ParallelMapData
):
"""
...
...
@@ -141,7 +151,7 @@ class MultiThreadMapData(_ParallelMapData):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
super
(
MultiThreadMapData
,
self
)
.
__init__
(
ds
,
buffer_size
)
super
(
MultiThreadMapData
,
self
)
.
__init__
(
ds
,
buffer_size
,
strict
)
self
.
_strict
=
strict
self
.
nr_thread
=
nr_thread
...
...
@@ -165,7 +175,6 @@ class MultiThreadMapData(_ParallelMapData):
for
t
in
self
.
_threads
:
t
.
start
()
self
.
_iter
=
self
.
ds
.
__iter__
()
self
.
_guard
=
DataFlowReentrantGuard
()
# Call once at the beginning, to ensure inq+outq has a total of buffer_size elements
...
...
@@ -179,12 +188,8 @@ class MultiThreadMapData(_ParallelMapData):
def
__iter__
(
self
):
with
self
.
_guard
:
if
self
.
_strict
:
for
dp
in
self
.
get_data_strict
():
yield
dp
else
:
for
dp
in
self
.
get_data_non_strict
():
yield
dp
for
dp
in
super
(
MultiThreadMapData
,
self
)
.
__iter__
():
yield
dp
def
__del__
(
self
):
if
self
.
_evt
is
not
None
:
...
...
@@ -245,7 +250,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
"""
_ParallelMapData
.
__init__
(
self
,
ds
,
buffer_size
)
_ParallelMapData
.
__init__
(
self
,
ds
,
buffer_size
,
strict
)
_MultiProcessZMQDataFlow
.
__init__
(
self
)
self
.
nr_proc
=
nr_proc
self
.
map_func
=
map_func
...
...
@@ -253,7 +258,10 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self
.
_procs
=
[]
self
.
_guard
=
DataFlowReentrantGuard
()
def
_reset_once
(
self
):
def
reset_state
(
self
):
_MultiProcessZMQDataFlow
.
reset_state
(
self
)
_ParallelMapData
.
reset_state
(
self
)
self
.
context
=
zmq
.
Context
()
self
.
socket
=
self
.
context
.
socket
(
zmq
.
DEALER
)
self
.
socket
.
set_hwm
(
self
.
_buffer_size
*
2
)
...
...
@@ -266,15 +274,9 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self
.
_proc_ids
[
k
],
self
.
map_func
,
pipename
,
worker_hwm
)
for
k
in
range
(
self
.
nr_proc
)]
self
.
ds
.
reset_state
()
self
.
_iter
=
self
.
ds
.
__iter__
()
self
.
_start_processes
()
self
.
_fill_buffer
()
# pre-fill the bufer
def
reset_state
(
self
):
_MultiProcessZMQDataFlow
.
reset_state
(
self
)
def
_send
(
self
,
dp
):
msg
=
[
b
""
,
dumps
(
dp
)]
self
.
socket
.
send_multipart
(
msg
,
copy
=
False
)
...
...
@@ -286,12 +288,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
def
__iter__
(
self
):
with
self
.
_guard
,
_zmq_catch_error
(
'MultiProcessMapData'
):
if
self
.
_strict
:
for
dp
in
self
.
get_data_strict
():
yield
dp
else
:
for
dp
in
self
.
get_data_non_strict
():
yield
dp
for
dp
in
super
(
MultiProcessMapDataZMQ
,
self
)
.
__iter__
():
yield
dp
MultiProcessMapData
=
MultiProcessMapDataZMQ
# alias
...
...
@@ -388,6 +386,8 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
if
__name__
==
'__main__'
:
import
time
class
Zero
(
DataFlow
):
def
__init__
(
self
,
size
):
self
.
_size
=
size
...
...
@@ -399,8 +399,13 @@ if __name__ == '__main__':
def
__len__
(
self
):
return
self
.
_size
ds
=
Zero
(
300
)
ds
=
MultiProcessMapData
(
ds
,
3
,
lambda
x
:
[
x
[
0
]
+
1
],
strict
=
True
)
def
f
(
x
):
if
x
[
0
]
<
10
:
time
.
sleep
(
1
)
return
x
ds
=
Zero
(
100
)
ds
=
MultiThreadMapData
(
ds
,
50
,
f
,
buffer_size
=
50
,
strict
=
False
)
ds
.
reset_state
()
for
k
in
ds
:
print
(
"Bang!"
,
k
)
...
...
tensorpack/input_source/input_source.py
View file @
0fa90c41
...
...
@@ -187,8 +187,6 @@ class EnqueueThread(ShareSessionThread):
class
QueueInput
(
FeedfreeInput
):
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors.
Calling :meth:`reset_state()` will clear the queue and reset the dataflow.
"""
def
__init__
(
self
,
ds
,
queue
=
None
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment