Commit f7993410 authored by Yuxin Wu

update docs and some rename

parent 3465e1a5
@@ -90,13 +90,13 @@ Alternatively, you can use multi-threaded preprocessing like this:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
-ds1 = ThreadedMapData(
+ds1 = MultiThreadMapData(
ds0, nr_thread=25,
map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds = BatchData(ds1, 256)
```
-`ThreadedMapData` launches a thread pool to fetch data and apply the mapping function on **a single
+`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
instance of** `ds0`. An intermediate buffer of size 1000 hides the mapping latency.
To reduce the effect of the GIL on your main training thread, uncomment the `PrefetchDataZMQ` line so that everything above it (including all the
threads) happens in an independent process.
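For illustration, here is a minimal sketch of that uncommented variant; `lots_of_augmentors` is a placeholder for your actual augmentor list, and the imports mirror those shown elsewhere in this commit:

```python
from tensorpack import imgaug, dataset
from tensorpack.dataflow import BatchData, MultiThreadMapData, PrefetchDataZMQ

lots_of_augmentors = []  # placeholder: fill in your imgaug augmentors
augmentor = imgaug.AugmentorList(lots_of_augmentors)

ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
# 25 threads apply the mapping on a single instance of ds0; a buffer of
# 1000 mapped datapoints hides the mapping latency.
ds1 = MultiThreadMapData(
    ds0, nr_thread=25,
    map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# Uncommented: ds0, the thread pool and the buffer now all live in one
# separate process, so the threads no longer compete with the main
# training thread for the GIL.
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds = BatchData(ds1, 256)
```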
@@ -115,7 +115,7 @@ If you identify this as a bottleneck, you can also use:
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
-ds1 = ThreadedMapData(
+ds1 = MultiThreadMapData(
ds0, nr_thread=25,
map_func=lambda dp:
[augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
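Put together, this variant looks roughly as follows. It is only a sketch: the augmentor list is a placeholder, and the buffer size and the final single-process `PrefetchDataZMQ` are taken from the other snippets in this commit:

```python
import cv2
from tensorpack import imgaug, dataset
from tensorpack.dataflow import BatchData, MultiThreadMapData, PrefetchDataZMQ

augmentor = imgaug.AugmentorList([])  # placeholder: your augmentor list

# The base dataflow only produces (filename, label) pairs, which is cheap.
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
# Decoding + augmentation, the expensive part, runs in 25 threads.
ds1 = MultiThreadMapData(
    ds0, nr_thread=25,
    map_func=lambda dp:
        [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
    buffer_size=1000)
# Move everything above into one separate process to avoid the GIL.
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds = BatchData(ds1, 256)
```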
@@ -220,7 +220,7 @@ launch the base DataFlow in only **one process**, and only parallelize the transformations
with another `PrefetchDataZMQ`
(Nesting two `PrefetchDataZMQ`, however, will result in different behavior.
These differences are explained in more detail in the API documentation.)
-Similar to what we did earlier, you can use `ThreadedMapData` to parallelize as well.
+Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.
Let me summarize what this DataFlow does:
......
@@ -12,7 +12,7 @@ from abc import abstractmethod
from tensorpack import imgaug, dataset, ModelDesc, InputDesc
from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ,
-BatchData, ThreadedMapData)
+BatchData, MultiThreadMapData)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
from tensorpack.utils.stats import RatioCounter
from tensorpack.models import regularize_cost
@@ -106,7 +106,7 @@ def get_imagenet_dataflow(
im = cv2.imread(fname, cv2.IMREAD_COLOR)
im = aug.augment(im)
return im, cls
-ds = ThreadedMapData(ds, cpu, mapf, buffer_size=2000, strict=True)
+ds = MultiThreadMapData(ds, cpu, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1)
return ds
......
@@ -121,10 +121,14 @@ class PrefetchData(ProxyDataFlow):
process by a Python :class:`multiprocessing.Queue`.
Note:
-1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
-   As a result, unless the underlying dataflow is fully shuffled, the data distribution
-   produced by this dataflow will be different.
-   (e.g. you are likely to see duplicated datapoints at the beginning)
+1. An iterator cannot run faster automatically -- what's happening is
+   that the underlying dataflow will be forked ``nr_proc`` times.
+   As a result, we have the following guarantees on the correctness of the dataflow:
+   a. When ``nr_proc=1``, the dataflow produces the same data as ``ds`` in the same order.
+   b. When ``nr_proc>1``, the dataflow produces the same distribution
+      of data as ``ds`` if each sample from ``ds`` is i.i.d. (e.g. fully shuffled).
You probably only want to use it for training.
2. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
3. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
A total of ``a`` instances of ``df`` worker processes will be created.
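A hedged usage sketch of the guarantee described above; keyword arguments are used because only the parameter names (not their order) are visible here, and the ILSVRC12 dataflow is borrowed from the tutorial snippet:

```python
from tensorpack import dataset
from tensorpack.dataflow import PrefetchData

df = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)

# nr_proc=1: same data as `df`, in the same order; the only effect is
# prefetching through a multiprocessing.Queue.
df_serial = PrefetchData(df, nr_prefetch=256, nr_proc=1)

# nr_proc=4: `df` is forked 4 times, so only the *distribution* of data is
# preserved, and only because the shuffled ILSVRC12 stream is roughly i.i.d.
# This is the case the new logger warning refers to; use it for training only.
df_parallel = PrefetchData(df, nr_prefetch=256, nr_proc=4)
```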
@@ -161,6 +165,10 @@ class PrefetchData(ProxyDataFlow):
self.nr_prefetch = nr_prefetch
self._guard = DataFlowReentrantGuard()
if nr_proc > 1:
logger.info("[PrefetchData] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [PrefetchData._Worker(self.ds, self.queue)
for _ in range(self.nr_proc)]
@@ -251,6 +259,9 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
self._guard = DataFlowReentrantGuard()
self._reset_done = False
if nr_proc > 1:
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
def _recv(self):
return loads(self.socket.recv(copy=False).bytes)
......
@@ -18,7 +18,7 @@ from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyC
from .base import Trainer
from .utility import LeastLoadedDeviceSetter, override_to_local_variable
-__all__ = ['MultiGPUTrainerBase', 'LeastLoadedDeviceSetter',
+__all__ = ['MultiGPUTrainerBase',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
......