Commit 2e7dd9c7 authored by Yuxin Wu's avatar Yuxin Wu

Add a third note in prefetch about forking

parent d6f0c57a
...@@ -16,7 +16,7 @@ from six.moves import range ...@@ -16,7 +16,7 @@ from six.moves import range
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm_kwargs from ..utils.utils import get_tqdm_kwargs
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..dataflow.base import DataFlow, DataFlowTerminated from ..dataflow.base import DataFlow
from ..graph_builder.input_source_base import InputSource from ..graph_builder.input_source_base import InputSource
from ..graph_builder.input_source import ( from ..graph_builder.input_source import (
...@@ -118,8 +118,8 @@ class InferenceRunnerBase(Callback): ...@@ -118,8 +118,8 @@ class InferenceRunnerBase(Callback):
try: try:
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()): for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
self._hooked_sess.run(fetches=[]) self._hooked_sess.run(fetches=[])
except (StopIteration, DataFlowTerminated, except (StopIteration, tf.errors.CancelledError,
tf.errors.CancelledError, tf.errors.OutOfRangeError): tf.errors.OutOfRangeError):
logger.error( logger.error(
"[InferenceRunner] input stopped before reaching its size()! " + msg) "[InferenceRunner] input stopped before reaching its size()! " + msg)
raise raise
......
...@@ -12,6 +12,11 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated'] ...@@ -12,6 +12,11 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
class DataFlowTerminated(BaseException): class DataFlowTerminated(BaseException):
"""
An exception indicating that the DataFlow is unable to produce any more data:
calling :meth:`get_data` will not give a valid iterator any more.
In most DataFlow this will not be raised.
"""
pass pass
......
...@@ -50,8 +50,13 @@ class PrefetchData(ProxyDataFlow): ...@@ -50,8 +50,13 @@ class PrefetchData(ProxyDataFlow):
Note: Note:
1. This is significantly slower than :class:`PrefetchDataZMQ` when data is large. 1. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, a), b)``. 2. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
A total of ``a`` instances of ``df`` worker processes will be created. A total of ``a`` instances of ``df`` worker processes will be created.
 This is different from the behavior of :class:`PrefetchDataZMQ`.
 3. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
(e.g. you are likely to see duplicated datapoints at the beginning)
""" """
def __init__(self, ds, nr_prefetch, nr_proc=1): def __init__(self, ds, nr_prefetch, nr_proc=1):
""" """
...@@ -115,6 +120,10 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -115,6 +120,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe. 1. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, a), b)``. 2. When nesting like this: ``PrefetchDataZMQ(PrefetchDataZMQ(df, a), b)``.
A total of ``a * b`` instances of ``df`` worker processes will be created. A total of ``a * b`` instances of ``df`` worker processes will be created.
 3. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
(e.g. you are likely to see duplicated datapoints at the beginning)
""" """
def __init__(self, ds, nr_proc=1, hwm=50): def __init__(self, ds, nr_proc=1, hwm=50):
""" """
...@@ -234,6 +243,7 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -234,6 +243,7 @@ class ThreadedMapData(ProxyDataFlow):
self.inq = inq self.inq = inq
self.outq = outq self.outq = outq
self.func = map_func self.func = map_func
self.daemon = True
def run(self): def run(self):
while not self.stopped(): while not self.stopped():
...@@ -251,7 +261,8 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -251,7 +261,8 @@ class ThreadedMapData(ProxyDataFlow):
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
""" """
super(ThreadedMapData, self).__init__(ds) super(ThreadedMapData, self).__init__(ds)
self.infinite_ds = RepeatedData(ds, -1)
self._iter_ds = RepeatedData(ds, -1)
self.nr_thread = nr_thread self.nr_thread = nr_thread
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.map_func = map_func self.map_func = map_func
...@@ -271,15 +282,19 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -271,15 +282,19 @@ class ThreadedMapData(ProxyDataFlow):
t.start() t.start()
# fill the buffer # fill the buffer
self._itr = self.infinite_ds.get_data() self._itr = self._iter_ds.get_data()
self._fill_buffer() self._fill_buffer()
def _fill_buffer(self): def _fill_buffer(self):
n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize() n = self.buffer_size - self._in_queue.qsize() - self._out_queue.qsize()
if n <= 0: if n <= 0:
return return
try:
for _ in range(n): for _ in range(n):
self._in_queue.put(next(self._itr)) self._in_queue.put(next(self._itr))
except StopIteration:
logger.error("[ThreadedMapData] buffer_size cannot be larger than the size of the DataFlow!")
raise
def get_data(self): def get_data(self):
self._fill_buffer() self._fill_buffer()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment