Commit f7993410 authored by Yuxin Wu

update docs and some renames

parent 3465e1a5
@@ -90,13 +90,13 @@ Alternatively, you can use multi-threaded preprocessing like this:
 ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
 augmentor = AugmentorList(lots_of_augmentors)
-ds1 = ThreadedMapData(
+ds1 = MultiThreadMapData(
     ds0, nr_thread=25,
     map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
 # ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
 ds = BatchData(ds1, 256)
 ```
-`ThreadedMapData` launches a thread pool to fetch data and apply the mapping function on **a single
+`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
 instance of** `ds0`. This is done by an intermediate buffer of size 1000 to hide the mapping latency.
 To reduce the effect of the GIL on your main training thread, you want to uncomment the line so that everything above it (including all the
 threads) happens in an independent process.
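For reference, here is the same pipeline assembled with that line uncommented. This is a minimal sketch, with `lots_of_augmentors` standing in for your actual augmentor list:

```python
from tensorpack import dataset, imgaug
from tensorpack.dataflow import BatchData, MultiThreadMapData, PrefetchDataZMQ

ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = imgaug.AugmentorList(lots_of_augmentors)  # placeholder augmentor list
ds1 = MultiThreadMapData(
    ds0, nr_thread=25,
    map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# Everything above (ds0 and the 25 mapping threads) now runs in one separate
# process, so those threads no longer contend for the trainer's GIL:
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds = BatchData(ds1, 256)
```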
@@ -115,7 +115,7 @@ If you identify this as a bottleneck, you can also use:
 ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
 augmentor = AugmentorList(lots_of_augmentors)
-ds1 = ThreadedMapData(
+ds1 = MultiThreadMapData(
     ds0, nr_thread=25,
     map_func=lambda dp:
         [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
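The hunk is truncated here. Assuming this file-reading pipeline finishes the same way as the first snippet, a plausible completion looks like the following; the `buffer_size`, `PrefetchDataZMQ`, and `BatchData` lines are inferred from that earlier snippet, not visible in the diff:

```python
import cv2
from tensorpack import dataset, imgaug
from tensorpack.dataflow import BatchData, MultiThreadMapData, PrefetchDataZMQ

ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = imgaug.AugmentorList(lots_of_augmentors)  # placeholder augmentor list
ds1 = MultiThreadMapData(
    ds0, nr_thread=25,
    map_func=lambda dp:
        [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
    buffer_size=1000)                    # assumed, mirroring the first snippet
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)    # assumed: same GIL-isolation trick
ds = BatchData(ds1, 256)                 # assumed batch size
```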
@@ -220,7 +220,7 @@ launch the base DataFlow in only **one process**, and only parallelize the transformations
 with another `PrefetchDataZMQ`
 (Nesting two `PrefetchDataZMQ`, however, will result in a different behavior.
 These differences are explained in the API documentation in more detail.).
-Similar to what we did earlier, you can use `ThreadedMapData` to parallelize as well.
+Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.
 Let me summarize what this DataFlow does:
...
@@ -12,7 +12,7 @@ from abc import abstractmethod
 from tensorpack import imgaug, dataset, ModelDesc, InputDesc
 from tensorpack.dataflow import (
     AugmentImageComponent, PrefetchDataZMQ,
-    BatchData, ThreadedMapData)
+    BatchData, MultiThreadMapData)
 from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
 from tensorpack.utils.stats import RatioCounter
 from tensorpack.models import regularize_cost
@@ -106,7 +106,7 @@ def get_imagenet_dataflow(
         im = cv2.imread(fname, cv2.IMREAD_COLOR)
         im = aug.augment(im)
         return im, cls
-    ds = ThreadedMapData(ds, cpu, mapf, buffer_size=2000, strict=True)
+    ds = MultiThreadMapData(ds, cpu, mapf, buffer_size=2000, strict=True)
     ds = BatchData(ds, batch_size, remainder=True)
     ds = PrefetchDataZMQ(ds, 1)
     return ds
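For orientation, a call into this helper might look as follows. The full signature of `get_imagenet_dataflow` is not visible in this hunk, so the argument names below are assumptions inferred from the variables the body uses:

```python
from tensorpack import imgaug

# Hypothetical usage; argument names are assumed, not taken from the diff.
augs = [imgaug.ResizeShortestEdge(256), imgaug.CenterCrop((224, 224))]
df = get_imagenet_dataflow('/path/to/ILSVRC12', 'val',
                           batch_size=64, augmentors=augs)
df.reset_state()
for im, cls in df.get_data():   # each datapoint: (image batch, label batch)
    break
```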
...
@@ -121,10 +121,14 @@ class PrefetchData(ProxyDataFlow):
     process by a Python :class:`multiprocessing.Queue`.
     Note:
-        1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
-           As a result, unless the underlying dataflow is fully shuffled, the data distribution
-           produced by this dataflow will be different.
-           (e.g. you are likely to see duplicated datapoints at the beginning)
+        1. An iterator cannot run faster automatically -- what's happening is
+           that the underlying dataflow will be forked ``nr_proc`` times.
+           As a result, we have the following guarantee on the dataflow correctness:
+
+           a. When ``nr_proc=1``, the dataflow produces the same data as ``ds`` in the same order.
+           b. When ``nr_proc>1``, the dataflow produces the same distribution
+              of data as ``ds`` if each sample from ``ds`` is i.i.d. (e.g. fully shuffled).
+
            You probably only want to use it for training.
         2. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
         3. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``,
            a total of ``a`` instances of ``df`` worker processes will be created.
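To make note 3 concrete, the nesting it describes looks like this. A hedged sketch, with `df` standing for any dataflow and arbitrary numbers:

```python
from tensorpack.dataflow import PrefetchData, PrefetchDataZMQ

a, b = 4, 2
# Per note 3: a total of `a` (= 4) instances of `df` worker processes
# are created here, not a * b.
nested = PrefetchDataZMQ(PrefetchData(df, nr_prefetch=256, nr_proc=a), nr_proc=b)
```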
@@ -161,6 +165,10 @@ class PrefetchData(ProxyDataFlow):
         self.nr_prefetch = nr_prefetch
         self._guard = DataFlowReentrantGuard()
+        if nr_proc > 1:
+            logger.info("[PrefetchData] Will fork a dataflow more than once. "
+                        "This assumes the datapoints are i.i.d.")
         self.queue = mp.Queue(self.nr_prefetch)
         self.procs = [PrefetchData._Worker(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
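Assuming the constructor shown above, the new notice fires only in the multi-process case. A minimal sketch:

```python
from tensorpack.dataflow import DataFromList, PrefetchData

# Toy source; each datapoint is a one-component list.
df = DataFromList([[i] for i in range(100)], shuffle=True)

PrefetchData(df, nr_prefetch=64, nr_proc=1)  # silent: same data, same order as df
PrefetchData(df, nr_prefetch=64, nr_proc=4)  # logs the i.i.d. notice: df is forked
                                             # 4 times, so ordering is no longer
                                             # guaranteed to match df
```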
@@ -251,6 +259,9 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
         self._guard = DataFlowReentrantGuard()
         self._reset_done = False
+        if nr_proc > 1:
+            logger.info("[PrefetchDataZMQ] Will fork a dataflow more than once. "
+                        "This assumes the datapoints are i.i.d.")
     def _recv(self):
         return loads(self.socket.recv(copy=False).bytes)
...
@@ -18,7 +18,7 @@ from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyC
 from .base import Trainer
 from .utility import LeastLoadedDeviceSetter, override_to_local_variable
-__all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
+__all__ = ['MultiGPUTrainerBase',
            'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
...