Commit 50ff9036 authored by Yuxin Wu

Add MultiProcessMapAndBatchData

parent ba9d1793
......@@ -35,6 +35,11 @@ class TestDataSpeed(ProxyDataFlow):
super(TestDataSpeed, self).__init__(ds)
self.test_size = int(size)
self.warmup = int(warmup)
self._reset_called = False
def reset_state(self):
self._reset_called = True
super(TestDataSpeed, self).reset_state()
def __iter__(self):
""" Will run testing at the beginning, then produce data normally. """
......@@ -46,6 +51,7 @@ class TestDataSpeed(ProxyDataFlow):
"""
Start testing with a progress bar.
"""
if not self._reset_called:
self.ds.reset_state()
itr = self.ds.__iter__()
if self.warmup:
......@@ -91,6 +97,7 @@ class BatchData(ProxyDataFlow):
except NotImplementedError:
pass
self.batch_size = int(batch_size)
assert self.batch_size > 0
self.remainder = remainder
self.use_list = use_list
......@@ -111,10 +118,10 @@ class BatchData(ProxyDataFlow):
for data in self.ds:
holder.append(data)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder, self.use_list)
yield BatchData.aggregate_batch(holder, self.use_list)
del holder[:]
if self.remainder and len(holder) > 0:
yield BatchData._aggregate_batch(holder, self.use_list)
yield BatchData.aggregate_batch(holder, self.use_list)
@staticmethod
def _batch_numpy(data_list):
......@@ -146,7 +153,18 @@ class BatchData(ProxyDataFlow):
pass
@staticmethod
def _aggregate_batch(data_holder, use_list=False):
def aggregate_batch(data_holder, use_list=False):
"""
Aggregate a list of datapoints into one batched datapoint.
Args:
data_holder (list[dp]): each dp is either a list or a dict.
use_list (bool): whether to batch data into a list or a numpy array.
Returns:
dp: either a list or a dict, depending on the inputs.
Each item is a batched version of the corresponding input.
"""
first_dp = data_holder[0]
if isinstance(first_dp, (list, tuple)):
result = []
......@@ -164,6 +182,8 @@ class BatchData(ProxyDataFlow):
result[key] = data_list
else:
result[key] = BatchData._batch_numpy(data_list)
else:
raise ValueError("Data point has to be list/tuple/dict. Got {}".format(type(first_dp)))
return result
......@@ -202,7 +222,7 @@ class BatchDataByShape(BatchData):
holder = self.holder[shp]
holder.append(dp)
if len(holder) == self.batch_size:
yield BatchData._aggregate_batch(holder)
yield BatchData.aggregate_batch(holder)
del holder[:]
......
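Since `_aggregate_batch` is renamed to the public `BatchData.aggregate_batch` above, it can now be called directly. A minimal sketch of its behavior, not part of the commit and assuming a tensorpack build that includes this change:

```python
import numpy as np
from tensorpack.dataflow import BatchData

# list-style datapoints: each component is stacked into a numpy array
dps = [[np.zeros((2, 2)), 1], [np.ones((2, 2)), 2]]
batch = BatchData.aggregate_batch(dps)
print(batch[0].shape, batch[1].shape)   # (2, 2, 2) (2,)

# dict-style datapoints: keys are kept, values are stacked per key
dps = [{"img": np.zeros((2, 2)), "label": 1},
       {"img": np.ones((2, 2)), "label": 2}]
batch = BatchData.aggregate_batch(dps)
print(batch["img"].shape, batch["label"].shape)   # (2, 2, 2) (2,)
```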
......@@ -12,11 +12,12 @@ from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps, loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData
from .common import RepeatedData, BatchData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
__all__ = ['MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
'MultiProcessMapData', 'MultiProcessMapDataZMQ',
'MultiProcessMapAndBatchData', 'MultiProcessMapAndBatchDataZMQ']
class _ParallelMapData(ProxyDataFlow):
......@@ -286,6 +287,9 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self._strict = strict
self._procs = []
def _create_worker(self, id, pipename, hwm):
return MultiProcessMapDataZMQ._Worker(id, self.map_func, pipename, hwm)
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
_ParallelMapData.reset_state(self)
......@@ -299,8 +303,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = int(self._buffer_size * 2 // self.num_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
self._procs = [self._create_worker(self._proc_ids[k], pipename, worker_hwm)
for k in range(self.num_proc)]
self._start_processes()
......@@ -316,12 +319,120 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
return dp
def __iter__(self):
with self._guard, _zmq_catch_error('MultiProcessMapData'):
with self._guard, _zmq_catch_error(type(self).__name__):
for dp in super(MultiProcessMapDataZMQ, self).__iter__():
yield dp
MultiProcessMapData = MultiProcessMapDataZMQ # alias
class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
"""
Similar to :class:`MultiProcessMapDataZMQ`, except that this DataFlow
also does batching in parallel in the worker processes.
Therefore it can be helpful if you wish to hide the latency of batching.
When `num_proc==1`, the behavior of this class is identical to
`BatchData(MapData(ds, map_func), batch_size)`.
When `num_proc>1`, the datapoints may be grouped in arbitrary order,
or grouped with datapoints from a different pass of the given dataflow.
"""
class _Dispatcher(mp.Process):
def __init__(self, ds, pipename, hwm):
super(MultiProcessMapAndBatchDataZMQ._Dispatcher, self).__init__()
self.ds = RepeatedData(ds, -1)
self.pipename = pipename
self.hwm = hwm
def run(self):
enable_death_signal()
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(self.hwm)
socket.bind(self.pipename)
self.ds.reset_state()
for dp in self.ds:
socket.send(dumps(dp), copy=False)
class _Worker(mp.Process):
def __init__(self, identity, map_func, input_pipe, result_pipe, hwm, batch_size):
super(MultiProcessMapAndBatchDataZMQ._Worker, self).__init__()
self.identity = identity
self.map_func = map_func
self.input_pipe = input_pipe
self.result_pipe = result_pipe
self.hwm = hwm
self.batch_size = batch_size
def run(self):
enable_death_signal(_warn=self.identity == b'0')
ctx = zmq.Context()
socket = ctx.socket(zmq.PULL)
socket.setsockopt(zmq.IDENTITY, self.identity)
socket.set_hwm(self.hwm)
socket.connect(self.input_pipe)
out_socket = ctx.socket(zmq.PUSH)
out_socket.set_hwm(max(self.hwm // self.batch_size, 5))
out_socket.connect(self.result_pipe)
batch = []
while True:
dp = loads(socket.recv(copy=False))
dp = self.map_func(dp)
if dp is not None:
batch.append(dp)
if len(batch) == self.batch_size:
dp = BatchData.aggregate_batch(batch)
out_socket.send(dumps(dp), copy=False)
del batch[:]
def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=1024):
"""
Args:
ds (DataFlow): the dataflow to map
num_proc (int): number of worker processes to use
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
batch_size (int): batch size
buffer_size (int): number of datapoints in the buffer
"""
super(MultiProcessMapAndBatchDataZMQ, self).__init__()
self.ds = ds
self.num_proc = num_proc
self.map_func = map_func
self.buffer_size = buffer_size
self.batch_size = batch_size
assert self.batch_size < buffer_size
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
self._guard = DataFlowReentrantGuard()
job_pipe = _get_pipe_name("dataflow_MaB_job")
result_pipe = _get_pipe_name("dataflow_MaB_result")
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(self.buffer_size * 2 // self.batch_size)
_bind_guard(self.socket, result_pipe)
dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = int(self.buffer_size * 2 // self.num_proc)
self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
for k in range(self.num_proc)]
self._procs.append(dispatcher)
self._start_processes()
def __iter__(self):
with self._guard, _zmq_catch_error(type(self).__name__):
while True:
yield loads(self.socket.recv(copy=False))
def _pool_map(data):
......@@ -414,6 +525,11 @@ class MultiProcessMapDataComponentSharedArray(DataFlow):
yield dp
# alias
MultiProcessMapData = MultiProcessMapDataZMQ
MultiProcessMapAndBatchData = MultiProcessMapAndBatchDataZMQ
if __name__ == '__main__':
import time
......
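A minimal usage sketch of the new `MultiProcessMapAndBatchData` (illustrative only; `load_and_augment`, the file list, and all sizes are hypothetical, and the ZMQ version assumes a fork-capable POSIX platform):

```python
import numpy as np
from tensorpack.dataflow import DataFromList, MultiProcessMapAndBatchData

def load_and_augment(dp):
    # dp is one datapoint from the source dataflow; return None to drop it
    fname, label = dp
    img = np.random.rand(224, 224, 3).astype("float32")  # stand-in for real decoding
    return [img, label]

ds = DataFromList([("{}.jpg".format(i), i % 10) for i in range(1000)], shuffle=True)
# 8 worker processes each run map_func and emit already-batched datapoints
ds = MultiProcessMapAndBatchData(ds, num_proc=8, map_func=load_and_augment,
                                 batch_size=32, buffer_size=1024)
ds.reset_state()
for imgs, labels in ds:   # imgs: (32, 224, 224, 3), labels: (32,)
    break
```

Because the workers batch before pushing results back, the main process deserializes one message per batch rather than one per datapoint, which is where the latency saving comes from.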
......@@ -33,8 +33,12 @@ def build_or_reuse_placeholder(tensor_spec):
assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
assert tensor_spec.is_compatible_with(tensor), \
"Tensor {} exists but is not compatible with the signature!".format(tensor)
if tensor.shape == tensor_spec.shape:
# It might be desirable to use a placeholder of a different shape in some tower
# (e.g., a less specific shape)
return tensor
except KeyError:
pass
with tfv1.name_scope(None): # clear any name scope it might get called in
ret = tfv1.placeholder(
tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
......
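For context, a small illustration (not from the commit) of the shape check added above: a signature with a less specific shape still passes the compatibility assert, but the exact-equality test makes the function create a fresh placeholder instead of returning the existing one:

```python
import tensorflow.compat.v1 as tfv1
tfv1.disable_eager_execution()   # placeholders require graph mode under TF2

existing = tfv1.placeholder(tfv1.float32, shape=[None, 224, 224, 3], name="image")
spec = tfv1.TensorSpec(shape=[None, None, None, 3], dtype=tfv1.float32, name="image")

print(spec.is_compatible_with(existing))   # True  -> passes the compatibility assert
print(existing.shape == spec.shape)        # False -> a new placeholder is built
```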
......@@ -454,14 +454,21 @@ class TFDatasetInput(FeedfreeInput):
def __init__(self, dataset):
"""
Args:
dataset (tf.data.Dataset):
dataset (tf.data.Dataset or DataFlow): if a DataFlow, the dataflow
has to be infinite.
"""
if not isinstance(dataset, tf.data.Dataset):
raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
if isinstance(dataset, tf.data.Dataset):
self._dataset = dataset
self._dataflow = None
elif isinstance(dataset, DataFlow):
self._dataset = None
self._dataflow = dataset
else:
raise ValueError("TFDatasetInput takes a tf.data.Dataset or DataFlow! Got {}".format(dataset))
def _setup(self, input_signature):
self._spec = input_signature
if self._dataset is not None:
types = self._dataset.output_types
spec_types = tuple([k.dtype for k in input_signature])
assert len(types) == len(spec_types), \
......@@ -470,6 +477,7 @@ class TFDatasetInput(FeedfreeInput):
assert types == spec_types, \
"Data types of dataset and input signature don't match! {} != {}".format(
str(types), str(spec_types))
shapes = self._dataset.output_shapes
spec_shapes = [k.shape for k in input_signature]
for idx, (s1, s2) in enumerate(zip(shapes, spec_shapes)):
......@@ -477,6 +485,9 @@ class TFDatasetInput(FeedfreeInput):
assert s2.is_compatible_with(s1), \
"Input signature '{}' has incompatible shape with dataset! {} vs {}".format(
input_signature[idx].name, s2, s1)
else:
self._dataset = TFDatasetInput.dataflow_to_dataset(self._dataflow, [x.dtype for x in input_signature])
self._iterator = self._dataset.make_initializable_iterator()
self._init_op = self._iterator.initializer
......
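With the change above, `TFDatasetInput` also accepts an infinite DataFlow and converts it to a `tf.data.Dataset` during setup, using the dtypes from the input signature. A hypothetical sketch (the generator and shapes are made up):

```python
import numpy as np
from tensorpack import TFDatasetInput
from tensorpack.dataflow import DataFromGenerator

def gen():
    while True:   # the docstring requires the dataflow to be infinite
        yield [np.random.rand(28, 28).astype("float32"), np.random.randint(10)]

df = DataFromGenerator(gen)
data = TFDatasetInput(df)   # previously only a tf.data.Dataset was accepted
# `data` can then be passed as the `data=` argument of a TrainConfig.
```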