Commit 2e7dd9c7 authored by Yuxin Wu's avatar Yuxin Wu

Add a third note in prefetch about forking

parent d6f0c57a
......@@ -16,7 +16,7 @@ from six.moves import range
from ..utils import logger
from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated
from ..dataflow.base import DataFlow, DataFlowTerminated
from ..dataflow.base import DataFlow
from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import (
......@@ -118,8 +118,8 @@ class InferenceRunnerBase(Callback):
try:
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
self._hooked_sess.run(fetches=[])
except (StopIteration, DataFlowTerminated,
tf.errors.CancelledError, tf.errors.OutOfRangeError):
except (StopIteration, tf.errors.CancelledError,
tf.errors.OutOfRangeError):
logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg)
raise
......
......@@ -12,6 +12,11 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
class DataFlowTerminated(BaseException):
"""
An exception indicating that the DataFlow is unable to produce any more data:
calling :meth:`get_data` will not give a valid iterator any more.
In most DataFlow this will not be raised.
"""
pass
......
......@@ -50,8 +50,13 @@ class PrefetchData(ProxyDataFlow):
Note:
1. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, a), b)``.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
A total of ``a`` instances of ``df`` worker processes will be created.
This is different from the behavior of :class:`PrefetchDataZMQ`.
3. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
(e.g. you are likely to see duplicated datapoints at the beginning)
"""
def __init__(self, ds, nr_prefetch, nr_proc=1):
"""
......@@ -115,6 +120,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, a), b)``.
A total of ``a * b`` instances of ``df`` worker processes will be created.
3. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
(e.g. you are likely to see duplicated datapoints at the beginning)
"""
def __init__(self, ds, nr_proc=1, hwm=50):
"""
......@@ -234,6 +243,7 @@ class ThreadedMapData(ProxyDataFlow):
self.inq = inq
self.outq = outq
self.func = map_func
self.daemon = True
def run(self):
while not self.stopped():
......@@ -251,7 +261,8 @@ class ThreadedMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer
"""
super(ThreadedMapData, self).__init__(ds)
self.infinite_ds = RepeatedData(ds, -1)
self._iter_ds = RepeatedData(ds, -1)
self.nr_thread = nr_thread
self.buffer_size = buffer_size
self.map_func = map_func
......@@ -271,15 +282,19 @@ class ThreadedMapData(ProxyDataFlow):
t.start()
# fill the buffer
self._itr = self.infinite_ds.get_data()
self._itr = self._iter_ds.get_data()
self._fill_buffer()
def _fill_buffer(self):
n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
if n <= 0:
return
for _ in range(n):
self._in_queue.put(next(self._itr))
try:
for _ in range(n):
self._in_queue.put(next(self._itr))
except StopIteration:
logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def get_data(self):
self._fill_buffer()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment