Commit 50ff9036, authored Aug 21, 2019 by Yuxin Wu
Add MultiProcessMapAndBatchData

parent ba9d1793
Showing 4 changed files with 187 additions and 36 deletions (+187 -36):

tensorpack/dataflow/common.py (+25 -5)
tensorpack/dataflow/parallel_map.py (+123 -7)
tensorpack/graph_builder/model_desc.py (+9 -5)
tensorpack/input_source/input_source.py (+30 -19)
tensorpack/dataflow/common.py

@@ -35,6 +35,11 @@ class TestDataSpeed(ProxyDataFlow):
         super(TestDataSpeed, self).__init__(ds)
         self.test_size = int(size)
         self.warmup = int(warmup)
+        self._reset_called = False
+
+    def reset_state(self):
+        self._reset_called = True
+        super(TestDataSpeed, self).reset_state()
 
     def __iter__(self):
         """ Will run testing at the beginning, then produce data normally. """
@@ -46,6 +51,7 @@ class TestDataSpeed(ProxyDataFlow):
         """
         Start testing with a progress bar.
         """
+        if not self._reset_called:
             self.ds.reset_state()
         itr = self.ds.__iter__()
         if self.warmup:
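
With this change, TestDataSpeed.start() resets the underlying dataflow only if reset_state() has not already been called. A minimal usage sketch, assuming a tensorpack build that includes this commit (FakeData and the sizes here are illustrative stand-ins):

from tensorpack.dataflow import FakeData, TestDataSpeed

df = FakeData([[224, 224, 3]], size=1000)       # stand-in dataflow
speed = TestDataSpeed(df, size=500, warmup=10)  # time 500 datapoints after 10 warmup iterations
speed.start()  # runs the speed test with a progress bar; resets df only if not reset already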
@@ -91,6 +97,7 @@ class BatchData(ProxyDataFlow):
         except NotImplementedError:
             pass
         self.batch_size = int(batch_size)
+        assert self.batch_size > 0
         self.remainder = remainder
         self.use_list = use_list
@@ -111,10 +118,10 @@ class BatchData(ProxyDataFlow):
         for data in self.ds:
             holder.append(data)
             if len(holder) == self.batch_size:
-                yield BatchData._aggregate_batch(holder, self.use_list)
+                yield BatchData.aggregate_batch(holder, self.use_list)
                 del holder[:]
         if self.remainder and len(holder) > 0:
-            yield BatchData._aggregate_batch(holder, self.use_list)
+            yield BatchData.aggregate_batch(holder, self.use_list)
 
     @staticmethod
     def _batch_numpy(data_list):
@@ -146,7 +153,18 @@ class BatchData(ProxyDataFlow):
         pass
 
     @staticmethod
-    def _aggregate_batch(data_holder, use_list=False):
+    def aggregate_batch(data_holder, use_list=False):
+        """
+        Aggregate a list of datapoints to one batched datapoint.
+
+        Args:
+            data_holder (list[dp]): each dp is either a list or a dict.
+            use_list (bool): whether to batch data into a list or a numpy array.
+
+        Returns:
+            dp: either a list or a dict, depend on the inputs.
+                Each item is a batched version of the corresponding inputs.
+        """
         first_dp = data_holder[0]
         if isinstance(first_dp, (list, tuple)):
             result = []
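
The renamed aggregate_batch is now part of the public BatchData API. A minimal sketch of calling it directly, assuming a tensorpack version that includes this commit (the sample datapoints are illustrative):

import numpy as np
from tensorpack.dataflow import BatchData

# Two datapoints, each an [image, label] pair.
datapoints = [
    [np.zeros((4, 4), dtype='float32'), 0],
    [np.ones((4, 4), dtype='float32'), 1],
]
batch = BatchData.aggregate_batch(datapoints)
print(batch[0].shape)   # (2, 4, 4) -- components are stacked along a new first axis
print(batch[1])         # [0 1]     -- scalar components become a 1-D array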
@@ -164,6 +182,8 @@ class BatchData(ProxyDataFlow):
                     result[key] = data_list
                 else:
                     result[key] = BatchData._batch_numpy(data_list)
+        else:
+            raise ValueError("Data point has to be list/tuple/dict. Got {}".format(type(first_dp)))
         return result
@@ -202,7 +222,7 @@ class BatchDataByShape(BatchData):
             holder = self.holder[shp]
             holder.append(dp)
             if len(holder) == self.batch_size:
-                yield BatchData._aggregate_batch(holder)
+                yield BatchData.aggregate_batch(holder)
                 del holder[:]
tensorpack/dataflow/parallel_map.py

@@ -12,11 +12,12 @@ from ..utils.concurrency import StoppableThread, enable_death_signal
 from ..utils.serialize import dumps, loads
 from ..utils.develop import log_deprecated
 from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
-from .common import RepeatedData
+from .common import RepeatedData, BatchData
 from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
 
 __all__ = ['MultiThreadMapData',
-           'MultiProcessMapData', 'MultiProcessMapDataZMQ']
+           'MultiProcessMapData', 'MultiProcessMapDataZMQ',
+           'MultiProcessMapAndBatchData', 'MultiProcessMapAndBatchDataZMQ']
 
 
 class _ParallelMapData(ProxyDataFlow):
@@ -286,6 +287,9 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
         self._strict = strict
         self._procs = []
 
+    def _create_worker(self, id, pipename, hwm):
+        return MultiProcessMapDataZMQ._Worker(id, self.map_func, pipename, hwm)
+
     def reset_state(self):
         _MultiProcessZMQDataFlow.reset_state(self)
         _ParallelMapData.reset_state(self)
@@ -299,8 +303,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
         self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
         worker_hwm = int(self._buffer_size * 2 // self.num_proc)
-        self._procs = [MultiProcessMapDataZMQ._Worker(
-            self._proc_ids[k], self.map_func, pipename, worker_hwm)
+        self._procs = [self._create_worker(self._proc_ids[k], pipename, worker_hwm)
             for k in range(self.num_proc)]
         self._start_processes()
@@ -316,12 +319,120 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
             return dp
 
     def __iter__(self):
-        with self._guard, _zmq_catch_error('MultiProcessMapData'):
+        with self._guard, _zmq_catch_error(type(self).__name__):
             for dp in super(MultiProcessMapDataZMQ, self).__iter__():
                 yield dp
 
 
-MultiProcessMapData = MultiProcessMapDataZMQ  # alias
+class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
+    """
+    Similar to :class:`MultiProcessMapDataZMQ`, except that this DataFlow
+    also does batching in parallel in the worker processes.
+    Therefore it can be helpful if you wish to hide the latency of batching.
+
+    When `nr_proc==1`, the behavior of this class is identical to
+    `BatchData(MapData(ds, map_func), batch_size)`.
+
+    When `nr_proc>1`, the datapoints may be grouped in arbitrary order,
+    or grouped with datapoints from a different pass of the given dataflow.
+    """
+
+    class _Dispatcher(mp.Process):
+        def __init__(self, ds, pipename, hwm):
+            super(MultiProcessMapAndBatchDataZMQ._Dispatcher, self).__init__()
+            self.ds = RepeatedData(ds, -1)
+            self.pipename = pipename
+            self.hwm = hwm
+
+        def run(self):
+            enable_death_signal()
+            ctx = zmq.Context()
+            socket = ctx.socket(zmq.PUSH)
+            socket.set_hwm(self.hwm)
+            socket.bind(self.pipename)
+            self.ds.reset_state()
+            for dp in self.ds:
+                socket.send(dumps(dp), copy=False)
+
+    class _Worker(mp.Process):
+        def __init__(self, identity, map_func, input_pipe, result_pipe, hwm, batch_size):
+            super(MultiProcessMapAndBatchDataZMQ._Worker, self).__init__()
+            self.identity = identity
+            self.map_func = map_func
+            self.input_pipe = input_pipe
+            self.result_pipe = result_pipe
+            self.hwm = hwm
+            self.batch_size = batch_size
+
+        def run(self):
+            enable_death_signal(_warn=self.identity == b'0')
+            ctx = zmq.Context()
+            socket = ctx.socket(zmq.PULL)
+            socket.setsockopt(zmq.IDENTITY, self.identity)
+            socket.set_hwm(self.hwm)
+            socket.connect(self.input_pipe)
+
+            out_socket = ctx.socket(zmq.PUSH)
+            out_socket.set_hwm(max(self.hwm // self.batch_size, 5))
+            out_socket.connect(self.result_pipe)
+
+            batch = []
+            while True:
+                dp = loads(socket.recv(copy=False))
+                dp = self.map_func(dp)
+                if dp is not None:
+                    batch.append(dp)
+                    if len(batch) == self.batch_size:
+                        dp = BatchData.aggregate_batch(batch)
+                        out_socket.send(dumps(dp), copy=False)
+                        del batch[:]
+
+    def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=1024):
+        """
+        Args:
+            ds (DataFlow): the dataflow to map
+            num_proc(int): number of threads to use
+            map_func (callable): datapoint -> datapoint | None. Return None to
+                discard/skip the datapoint.
+            batch_size (int): batch size
+            buffer_size (int): number of datapoints in the buffer
+        """
+        super(MultiProcessMapAndBatchDataZMQ, self).__init__()
+        self.ds = ds
+        self.num_proc = num_proc
+        self.map_func = map_func
+        self.buffer_size = buffer_size
+        self.batch_size = batch_size
+        assert self.batch_size < buffer_size
+
+    def reset_state(self):
+        _MultiProcessZMQDataFlow.reset_state(self)
+        self._guard = DataFlowReentrantGuard()
+
+        job_pipe = _get_pipe_name("dataflow_MaB_job")
+        result_pipe = _get_pipe_name("dataflow_MaB_result")
+
+        self.context = zmq.Context()
+        self.socket = self.context.socket(zmq.PULL)
+        self.socket.set_hwm(self.buffer_size * 2 // self.batch_size)
+        _bind_guard(self.socket, result_pipe)
+
+        dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)
+
+        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
+        worker_hwm = int(self.buffer_size * 2 // self.num_proc)
+        self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
+            self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
+            for k in range(self.num_proc)]
+        self._procs.append(dispatcher)
+        self._start_processes()
+
+    def __iter__(self):
+        with self._guard, _zmq_catch_error(type(self).__name__):
+            while True:
+                yield loads(self.socket.recv(copy=False))
 
 
 def _pool_map(data):
@@ -414,6 +525,11 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
                 yield dp
 
 
+# alias
+MultiProcessMapData = MultiProcessMapDataZMQ
+MultiProcessMapAndBatchData = MultiProcessMapAndBatchDataZMQ
+
+
 if __name__ == '__main__':
     import time
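
As the new docstring describes, MultiProcessMapAndBatchData(ZMQ) maps and batches datapoints inside the worker processes so the batching latency is hidden. A minimal usage sketch, assuming a tensorpack build that includes this commit (FakeData is a real tensorpack class used as a stand-in; slow_map and the sizes are illustrative):

from tensorpack.dataflow import FakeData, MultiProcessMapAndBatchData

def slow_map(dp):
    # dp is one datapoint (a list of components); return None to drop it.
    img, label = dp
    return [img * 2, label]

ds = FakeData([[64, 64, 3], [1]], size=1000)   # any DataFlow; it is repeated internally
ds = MultiProcessMapAndBatchData(ds, num_proc=4, map_func=slow_map,
                                 batch_size=32, buffer_size=1024)
ds.reset_state()
for batch in ds:
    # batch[0]: (32, 64, 64, 3) array, batch[1]: (32, 1) array
    break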
tensorpack/graph_builder/model_desc.py

@@ -33,8 +33,12 @@ def build_or_reuse_placeholder(tensor_spec):
         assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
         assert tensor_spec.is_compatible_with(tensor), \
             "Tensor {} exists but is not compatible with the signature!".format(tensor)
+        if tensor.shape == tensor_spec.shape:
+            # It might be desirable to use a placeholder of a different shape in some tower
+            # (e.g., a less specific shape)
             return tensor
     except KeyError:
         pass
     with tfv1.name_scope(None):   # clear any name scope it might get called in
         ret = tfv1.placeholder(
             tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
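
The change above means an existing placeholder is reused only when its shape matches the spec exactly; otherwise a new placeholder is created for that tower. A rough sketch of the behavior (graph-mode TF1 via the compat API; build_or_reuse_placeholder is internal tensorpack API and the spec name/shape here are illustrative):

import tensorflow.compat.v1 as tfv1
from tensorpack.graph_builder.model_desc import build_or_reuse_placeholder

with tfv1.Graph().as_default():
    spec = tfv1.TensorSpec([None, 224, 224, 3], tfv1.float32, name='input')
    p1 = build_or_reuse_placeholder(spec)   # creates placeholder 'input:0'
    p2 = build_or_reuse_placeholder(spec)   # same name, identical shape -> existing tensor returned
    print(p1 is p2)   # True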
tensorpack/input_source/input_source.py

@@ -454,14 +454,21 @@ class TFDatasetInput(FeedfreeInput):
     def __init__(self, dataset):
         """
         Args:
-            dataset (tf.data.Dataset):
+            dataset (tf.data.Dataset or DataFlow): if a DataFlow, the dataflow
+                has to be infinite.
         """
-        if not isinstance(dataset, tf.data.Dataset):
-            raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
-        self._dataset = dataset
+        if isinstance(dataset, tf.data.Dataset):
+            self._dataset = dataset
+            self._dataflow = None
+        elif isinstance(dataset, DataFlow):
+            self._dataset = None
+            self._dataflow = dataset
+        else:
+            raise ValueError("TFDatasetInput takes a tf.data.Dataset or DataFlow! Got {}".format(dataset))
 
     def _setup(self, input_signature):
         self._spec = input_signature
+        if self._dataset is not None:
             types = self._dataset.output_types
             spec_types = tuple([k.dtype for k in input_signature])
             assert len(types) == len(spec_types), \
@@ -470,6 +477,7 @@ class TFDatasetInput(FeedfreeInput):
             assert types == spec_types, \
                 "Data types of dataset and input signature don't match! {} != {}".format(
                     str(types), str(spec_types))
             shapes = self._dataset.output_shapes
             spec_shapes = [k.shape for k in input_signature]
             for idx, (s1, s2) in enumerate(zip(shapes, spec_shapes)):
@@ -477,6 +485,9 @@ class TFDatasetInput(FeedfreeInput):
                 assert s2.is_compatible_with(s1), \
                     "Input signature '{}' has incompatible shape with dataset! {} vs {}".format(
                         input_signature[idx].name, s2, s1)
+        else:
+            self._dataset = TFDatasetInput.dataflow_to_dataset(
+                self._dataflow, [x.dtype for x in input_signature])
         self._iterator = self._dataset.make_initializable_iterator()
         self._init_op = self._iterator.initializer
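
With this change TFDatasetInput also accepts a DataFlow directly and converts it with dataflow_to_dataset during setup. A minimal sketch, assuming a tensorpack build that includes this commit (FakeData is an illustrative stand-in; the dataflow must be infinite, hence RepeatedData):

from tensorpack.dataflow import FakeData, RepeatedData
from tensorpack.input_source import TFDatasetInput

df = RepeatedData(FakeData([[28, 28], [1]], size=1000), -1)   # make the dataflow infinite
inp = TFDatasetInput(df)   # previously only tf.data.Dataset was accepted
# inp is then passed to a trainer together with the input signature, as usual.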