Commit dd2d9ffa authored by Yuxin Wu's avatar Yuxin Wu

Rename PrefetchData -> MultiProcessRunner

parent 0cecfbb6
......@@ -372,7 +372,6 @@ _DEPRECATED_NAMES = set([
# deprecated stuff:
'QueueInputTrainer',
'dump_dataflow_to_process_queue',
'PrefetchOnGPUs',
'DistributedTrainerReplicated',
'DistributedTrainerParameterServer',
'InputDesc',
......@@ -382,11 +381,14 @@ _DEPRECATED_NAMES = set([
'DumpTensor',
'DumpParamAsImage',
'get_nr_gpu',
'start_test', # TestDataSpeed
'ThreadedMapData',
'TrainingMonitor',
'PeakMemoryTracker',
'PrefetchData',
'MultiProcessPrefetchData',
'PrefetchDataZMQ',
'MultiThreadPrefetchData',
# deprecated or renamed symbolic code
'Deconv2D', 'psnr',
......
......@@ -3,11 +3,11 @@
### What is DataFlow
DataFlow is a library to build Python iterators for efficient data loading.
DataFlow is a pure-Python library to create iterators for efficient data loading.
**Definition**: A DataFlow is a idiomatic Python container object that has a `__iter__()` generator method,
which yields `datapoints` and optionally a `__len__()` method returning the size of the flow.
A datapoint is a **list** of Python objects which are called the `components` of a datapoint.
**Definition**: A DataFlow is a idiomatic Python iterator object that has a `__iter__()` method
which yields `datapoints`, and optionally a `__len__()` method returning the size of the DataFlow.
A datapoint is a **list or dict** of Python objects, each of which are called the `components` of a datapoint.
**Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method
that yields datapoints (lists) of two components:
......@@ -21,12 +21,10 @@ You can simply use DataFlow as a data processing pipeline and plug it into any o
### Composition of DataFlow
One good thing about having a standard interface is to be able to provide
the greatest code reusability.
There are a lot of existing DataFlow utilities in tensorpack, which you can use to compose
DataFlow with complex data pipeline. A common pipeline usually
would __read from disk (or other sources), apply transformations, group into batches,
prefetch data__, etc. A simple example is as the following:
one DataFlow with complex data pipeline. A common pipeline usually
would __read from disk (or other sources), apply transformations (possibly in parallel), group into batches,
prefetch data__, etc, and all __run in parallel__. A simple example is as the following:
````python
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
......@@ -36,17 +34,17 @@ df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3)
df = MultiProcessRunnerZMQ(df, 3)
````
You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
with all the data preprocessing.
### Work with Your Data
Unless you are working with standard data types (image folders, LMDB, etc),
you would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your data format.
We do not make any assumptions about your data format.
You would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your own data format.
See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
Once you have the source reader, all the [existing
DataFlows](../modules/dataflow.html) are ready for you to build up the rest of the data pipeline.
Once you have the source reader, all the [built-in
DataFlows](../modules/dataflow.html) are ready for you to assemble the rest of the data pipeline.
### Why DataFlow
......@@ -62,16 +60,16 @@ Nevertheless, tensorpack supports data loading with native TF operators / TF dat
### Use DataFlow in Your Own Code
Normally, tensorpack `InputSource` interface runs the DataFlow during training.
However, DataFlow can also be used without other tensorpack components.
If you need to run the DataFlow by yourself, call `reset_state()` first to initialize it,
When training with tensorpack, typically it is the `InputSource` interface that runs the DataFlow.
However, DataFlow can be used without other tensorpack components.
To run a DataFlow by yourself, call `reset_state()` first to initialize it,
and then use the generator however you like:
```python
df = SomeDataFlow()
df.reset_state()
for dp in df:
# dp is now a list. do whatever
# dp is now a list. do whatever
```
Read the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlw)
......
......@@ -16,7 +16,7 @@ then apply complicated preprocessing to it.
We aim to reach a speed of, roughly **1k~3k images per second**, to keep GPUs busy.
Some things to know before reading:
1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess prefetch should usually work well enough.
1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess runner should usually work well enough.
Therefore you don't have to understand this tutorial in depth unless you really find your data being the bottleneck.
This tutorial could be a bit complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-scale dataset.
2. Having a fast Python generator **alone** may or may not improve your overall training speed.
......@@ -64,7 +64,7 @@ On a good filesystem you probably can already observe good speed here (e.g. 5 it
because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
Image decoding in `cv2.imread` could also be a bottleneck at this early stage.
### Parallel Prefetch
### Parallel Runner
We will now add the cheapest pre-processing now to get an ndarray in the end instead of a list
(because training will need ndarray eventually):
......@@ -84,15 +84,15 @@ Now it's time to add threads or processes:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
ds1 = AugmentImageComponent(ds0, lots_of_augmentors)
ds = PrefetchDataZMQ(ds1, nr_proc=25)
ds = MultiProcessRunnerZMQ(ds1, num_proc=25)
ds = BatchData(ds, 256)
```
Here we fork 25 processes to run `ds1`, and collect their output through ZMQ IPC protocol,
which is faster than `multiprocessing.Queue`. You can also apply prefetch after batch, of course.
which is faster than `multiprocessing.Queue`. You can also apply parallel runner after batching, of course.
### Parallel Map
The above DataFlow might be fast, but since it forks the ImageNet reader (`ds0`),
it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)).
it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)).
Alternatively, you can use multi-threaded preprocessing like this:
```eval_rst
......@@ -102,9 +102,9 @@ Alternatively, you can use multi-threaded preprocessing like this:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
ds0, nr_thread=25,
ds0, num_thread=25,
map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
# ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
......@@ -127,11 +127,11 @@ If you identify this as a bottleneck, you can also use:
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
ds0, nr_thread=25,
ds0, num_thread=25,
map_func=lambda dp:
[augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
buffer_size=1000)
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
......@@ -159,15 +159,15 @@ class BinaryILSVRC12(dataset.ILSVRC12Files):
jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
yield [jpeg, label]
ds0 = BinaryILSVRC12('/path/to/ILSVRC/', 'train')
ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
ds1 = MultiProcessRunnerZMQ(ds0, num_proc=1)
LMDBSerializer.save(ds1, '/path/to/ILSVRC-train.lmdb')
```
The above script builds a DataFlow which produces jpeg-encoded ImageNet data.
We store the jpeg string as a numpy array because the function `cv2.imdecode` later expect this format.
Please note we can only use 1 prefetch process to speed up. If `nr_proc>1`, `ds1` will take data
Please note we can only use 1 runner process to speed up. If `num_proc>1`, `ds1` will take data
from several forks of `ds0`, then neither the content nor the order of `ds1` will be the same as `ds0`.
See [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)
about caveats of `PrefetchDataZMQ`.
See [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)
about caveats of `MultiProcessRunnerZMQ`.
It will generate a database file of 140G. We load the DataFlow back by reading this LMDB file sequentially:
```
......@@ -193,7 +193,7 @@ the added line above maintains a buffer of datapoints and shuffle them once a wh
It will not affect the model as long as the buffer is large enough,
but it can also consume much memory if too large.
### Augmentations & Parallel Prefetch
### Augmentations & Parallel Runner
Then we add necessary transformations:
```eval_rst
......@@ -218,24 +218,24 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
ds = LMDBSerializer.load(db, shuffle=False)
ds = LocallyShuffleData(ds, 50000)
ds = PrefetchData(ds, 5000, 1)
ds = MultiProcessRunner(ds, 5000, 1)
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = PrefetchDataZMQ(ds, 25)
ds = MultiProcessRunnerZMQ(ds, 25)
ds = BatchData(ds, 256)
```
Since we are reading the database sequentially, having multiple forked instances of the
base LMDB reader will result in biased data distribution. Therefore we use `PrefetchData` to
base LMDB reader will result in biased data distribution. Therefore we use `MultiProcessRunner` to
launch the base DataFlow in only **one process**, and only parallelize the transformations
with another `PrefetchDataZMQ`
(Nesting two `PrefetchDataZMQ`, however, will result in a different behavior.
with another `MultiProcessRunnerZMQ`
(Nesting two `MultiProcessRunnerZMQ`, however, will result in a different behavior.
These differences are explained in the API documentation in more details.).
Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.
Let me summarize what this DataFlow does:
1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `PrefetchData`).
1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `MultiProcessRunner`).
2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and
send them through ZMQ IPC pipe.
3. The main process takes data from the pipe, makes batches.
......
......@@ -82,7 +82,7 @@ def get_data(path, isTrain, stat_file):
ds = MapDataComponent(ds, lambda x: (x - mean) / std)
ds = TIMITBatch(ds, BATCH)
if isTrain:
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
return ds
......
......@@ -32,7 +32,7 @@ def get_data():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5)
data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -148,7 +148,7 @@ def get_config():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchDataZMQ(data_train, 5)
data_train = MultiProcessRunnerZMQ(data_train, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -225,7 +225,7 @@ def get_data():
ds = ThetaImages(ds)
ds = RepeatedData(ds, 50) # just pretend this dataset is bigger
# this pre-computation is pretty heavy
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = MultiProcessRunnerZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH)
return ds
......
......@@ -9,7 +9,7 @@ from tabulate import tabulate
from termcolor import colored
from tensorpack.dataflow import (
DataFromList, MapData, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData,
DataFromList, MapData, MapDataComponent, MultiProcessMapData, MultiThreadMapData,
TestDataSpeed, imgaug)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized
......@@ -392,7 +392,7 @@ def get_train_dataflow():
# MPI does not like fork()
else:
buffer_size = cfg.DATA.NUM_WORKERS * 20
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
else:
ds = MapData(ds, preprocess)
return ds
......
......@@ -177,7 +177,7 @@ def get_data(datadir, isTrain=True):
names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
df = BatchData(df, BATCH if isTrain else TEST_BATCH)
df = PrefetchDataZMQ(df, 2 if isTrain else 1)
df = MultiProcessRunnerZMQ(df, 2 if isTrain else 1)
return df
......
......@@ -115,7 +115,7 @@ def get_data():
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, args.batch)
ds = PrefetchDataZMQ(ds, 5)
ds = MultiProcessRunnerZMQ(ds, 5)
return ds
......
......@@ -186,7 +186,7 @@ def get_celebA_data(datadir, styleA, styleB=None):
imgaug.Resize(64)]
df = AugmentImageComponents(df, augs, (0, 1))
df = BatchData(df, BATCH)
df = PrefetchDataZMQ(df, 3)
df = MultiProcessRunnerZMQ(df, 3)
return df
......
......@@ -173,7 +173,7 @@ def get_data():
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1)
ds = MultiProcessRunner(ds, 100, 1)
return ds
......
......@@ -233,7 +233,7 @@ def get_data(name):
]
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
else:
ds = BatchData(ds, 1)
return ds
......
......@@ -11,7 +11,9 @@ import tensorflow as tf
import tqdm
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.dataflow import (
AugmentImageComponent, BatchData, MultiThreadMapData,
MultiProcessRunnerZMQ, dataset, imgaug)
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost, l2_regularizer
from tensorpack.predict import FeedfreePredictor, PredictConfig
......@@ -88,7 +90,7 @@ def get_imagenet_dataflow(
ds = AugmentImageComponent(ds, augmentors, copy=False)
if parallel < 16:
logger.warn("DataFlow may become the bottleneck when too few processes are used.")
ds = PrefetchDataZMQ(ds, parallel)
ds = MultiProcessRunnerZMQ(ds, parallel)
ds = BatchData(ds, batch_size, remainder=False)
else:
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
......@@ -101,7 +103,7 @@ def get_imagenet_dataflow(
return im, cls
ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
return ds
......
......@@ -133,7 +133,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 3, 2)
ds = MultiProcessRunner(ds, 3, 2)
return ds
......
......@@ -68,7 +68,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain:
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = MultiProcessRunnerZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
......
......@@ -254,7 +254,7 @@ def get_data(file_name):
imgaug.Flip(horiz=True)]
ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
ds = PrefetchDataZMQ(ds, 3)
ds = MultiProcessRunnerZMQ(ds, 3)
ds = BatchData(ds, BATCH_SIZE)
return ds
......
......@@ -103,7 +103,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, 5)
ds = MultiProcessRunnerZMQ(ds, 5)
return ds
......
......@@ -78,7 +78,7 @@ def get_data():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5)
data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -37,7 +37,7 @@ def get_data(subset):
# something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
ds = PrefetchDataZMQ(ds, 2)
ds = MultiProcessRunnerZMQ(ds, 2)
ds = BatchData(ds, BATCH_SIZE)
return ds
......
......@@ -64,29 +64,25 @@ class DataFlow(object):
@abstractmethod
def __iter__(self):
"""
* A dataflow is an iterable. The :meth:`__iter__` method should yield a list each time.
Each element in the list should be either a number or a numpy array.
For now, tensorpack also **partially** supports dict instead of list.
* A dataflow is an iterable. The :meth:`__iter__` method should yield a list or dict each time.
Note that dict is **partially** supported at the moment: certain dataflow does not support dict.
* The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
(will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
again after the previous call returned.
again immediately after the previous call returned.
* For many dataflow, the :meth:`__iter__` method is non-reentrant, which means for an dataflow
instance ``df``, :meth:`df.__iter__` cannot be called before the previous
:meth:`df.__iter__` call has finished (iteration has stopped).
When it is non-reentrant, :meth:`df.__iter__` should throw an exception if
When a dataflow is non-reentrant, :meth:`df.__iter__` should throw an exception if
called before the previous call has finished.
For such non-reentrant dataflows, if you need to use the same dataflow in two places,
you need to create two dataflow instances.
Yields:
list: The datapoint, i.e. list of components.
list/dict: The datapoint, i.e. list/dict of components.
"""
def get_data(self):
return self.__iter__()
def __len__(self):
"""
* A dataflow can optionally implement :meth:`__len__`. If not implemented, it will
......@@ -95,7 +91,7 @@ class DataFlow(object):
* It returns an integer representing the size of the dataflow.
The return value **may not be accurate or meaningful** at all.
When saying the length is "accurate", it means that
:meth:`__iter__` will always yield this many of datapoints.
:meth:`__iter__` will always yield this many of datapoints before it stops iteration.
* There could be many reasons why :meth:`__len__` is inaccurate.
For example, some dataflow has dynamic size, if it throws away datapoints on the fly.
......@@ -103,8 +99,9 @@ class DataFlow(object):
the dataset, due to parallelism and buffering.
In this case it does not make sense to stop the iteration anywhere.
* Due to the above reasons, the length is only a rough guidance. Inside
tensorpack it's only used in these places:
* Due to the above reasons, the length is only a rough guidance.
And it's up to the user how to interpret it.
Inside tensorpack it's only used in these places:
+ A default ``steps_per_epoch`` in training, but you probably want to customize
it yourself, especially when using data-parallel trainer.
......@@ -121,9 +118,6 @@ class DataFlow(object):
"""
raise NotImplementedError()
def size(self):
return self.__len__()
def reset_state(self):
"""
* The caller must guarantee that :meth:`reset_state` should be called **once and only once**
......@@ -134,21 +128,28 @@ class DataFlow(object):
e.g., initialize random number generators (RNG), create worker processes.
Because it's very common to use RNG in data processing,
developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to an RNG.
developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to
a properly-initialized RNG.
* A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee).
A few number of dataflow is not fork-safe anytime, which will be mentioned in the docs.
There are a few other dataflows that are not fork-safe anytime, which will be mentioned in the docs.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either when you're using dataflow outside of tensorpack, or if you're writing a wrapper dataflow).
* Tensorpack's built-in forking dataflows (:class:`MultiProcessPrefetchData`, :class:`MultiProcessMapData`, etc)
* Tensorpack's built-in forking dataflows (:class:`MultiProcessRunner`, :class:`MultiProcessMapData`, etc)
and other component that uses dataflows (:class:`InputSource`)
already take care of the responsibility of calling this method.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either if you're using dtaflow outside of tensorpack,
or if you're writing a wrapper dataflow).
"""
pass
# These are the old (overly verbose) names for the methods:
def get_data(self):
return self.__iter__()
def size(self):
return self.__len__()
class RNGDataFlow(DataFlow):
""" A DataFlow with RNG"""
......@@ -156,7 +157,7 @@ class RNGDataFlow(DataFlow):
rng = None
"""
``self.rng`` is a ``np.random.RandomState`` instance that is initialized
correctly in ``RNGDataFlow.reset_state()``.
correctly (with different seeds in each process) in ``RNGDataFlow.reset_state()``.
"""
def reset_state(self):
......
......@@ -14,6 +14,7 @@ from termcolor import colored
from ..utils import logger
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
......@@ -23,7 +24,7 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'Fixed
class TestDataSpeed(ProxyDataFlow):
""" Test the speed of some DataFlow """
""" Test the speed of a DataFlow """
def __init__(self, ds, size=5000, warmup=0):
"""
Args:
......@@ -175,7 +176,7 @@ class BatchDataByShape(BatchData):
Note:
It is implemented by a dict{shape -> datapoints}.
Datapoints of uncommon shapes may never be enough to form a batch and
Therefore, datapoints of uncommon shapes may never be enough to form a batch and
never get generated.
"""
def __init__(self, ds, batch_size, idx):
......@@ -184,7 +185,7 @@ class BatchDataByShape(BatchData):
ds (DataFlow): input DataFlow. ``dp[idx]`` has to be an :class:`np.ndarray`.
batch_size (int): batch size
idx (int): ``dp[idx].shape`` will be used to group datapoints.
Other components are assumed to have the same shape.
Other components are assumed to be batch-able.
"""
super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
self.idx = idx
......@@ -267,13 +268,13 @@ class MapData(ProxyDataFlow):
Note:
1. Please make sure func doesn't modify its arguments in place,
unless you're certain it's safe.
2. If you discard some datapoints, ``len(ds)`` will be incorrect.
2. If you discard some datapoints, ``len(MapData(ds))`` will be incorrect.
Example:
.. code-block:: none
ds = Mnist('train)
ds = Mnist('train') # each datapoint is [img, label]
ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])
"""
......@@ -302,14 +303,14 @@ class MapDataComponent(MapData):
1. This dataflow itself doesn't modify the datapoints.
But please make sure func doesn't modify its arguments in place,
unless you're certain it's safe.
2. If you discard some datapoints, ``len(ds)`` will be incorrect.
2. If you discard some datapoints, ``len(MapDataComponent(ds, ..))`` will be incorrect.
Example:
.. code-block:: none
ds = Mnist('train)
ds = MapDataComponent(ds, lambda img: img * 255, 0)
ds = Mnist('train') # each datapoint is [img, label]
ds = MapDataComponent(ds, lambda img: img * 255, 0) # map the 0th component
"""
def __init__(self, ds, func, index=0):
"""
......@@ -340,32 +341,32 @@ class RepeatedData(ProxyDataFlow):
dp1, dp2, .... dpn, dp1, dp2, ....dpn
"""
def __init__(self, ds, nr):
def __init__(self, ds, num):
"""
Args:
ds (DataFlow): input DataFlow
nr (int): number of times to repeat ds.
num (int): number of times to repeat ds.
Set to -1 to repeat ``ds`` infinite times.
"""
self.nr = nr
self.num = num
super(RepeatedData, self).__init__(ds)
def __len__(self):
"""
Raises:
:class:`ValueError` when nr == -1.
:class:`ValueError` when num == -1.
"""
if self.nr == -1:
if self.num == -1:
raise NotImplementedError("__len__() is unavailable for infinite dataflow")
return len(self.ds) * self.nr
return len(self.ds) * self.num
def __iter__(self):
if self.nr == -1:
if self.num == -1:
while True:
for dp in self.ds:
yield dp
else:
for _ in range(self.nr):
for _ in range(self.num):
for dp in self.ds:
yield dp
......@@ -376,22 +377,22 @@ class RepeatedDataPoint(ProxyDataFlow):
dp1, dp1, ..., dp1, dp2, ..., dp2, ...
"""
def __init__(self, ds, nr):
def __init__(self, ds, num):
"""
Args:
ds (DataFlow): input DataFlow
nr (int): number of times to repeat each datapoint.
num (int): number of times to repeat each datapoint.
"""
self.nr = int(nr)
assert self.nr >= 1, self.nr
self.num = int(num)
assert self.num >= 1, self.num
super(RepeatedDataPoint, self).__init__(ds)
def __len__(self):
return len(self.ds) * self.nr
return len(self.ds) * self.num
def __iter__(self):
for dp in self.ds:
for _ in range(self.nr):
for _ in range(self.num):
yield dp
......@@ -474,7 +475,7 @@ class RandomMixData(RNGDataFlow):
class ConcatData(DataFlow):
"""
Concatenate several DataFlow.
Produce datapoints from each DataFlow and go to the next when one
Produce datapoints from each DataFlow and start the next when one
DataFlow is exhausted.
"""
......@@ -501,8 +502,8 @@ class ConcatData(DataFlow):
class JoinData(DataFlow):
"""
Join the components from each DataFlow. See below for its behavior.
Dataflow that produces lists and dataflow that produces dicts
cannot be joined.
Note that you can't join a DataFlow that produces lists with one that produces dicts.
Example:
......@@ -524,7 +525,7 @@ class JoinData(DataFlow):
When these dataflows have different sizes, JoinData will stop when any
of them is exhausted.
The list could contain the same DataFlow instance more than once,
but note that `__iter__` will then also be called many times.
but note that in that case `__iter__` will then also be called many times.
"""
self.df_lists = df_lists
......@@ -568,7 +569,7 @@ def SelectComponent(ds, idxs):
Args:
ds (DataFlow): input DataFlow.
idxs (list[int]): a list of component indices.
idxs (list[int] or list[str]): a list of component indices/keys.
Example:
......@@ -583,13 +584,13 @@ def SelectComponent(ds, idxs):
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
""" Buffer the datapoints from a given dataflow, and shuffle them before producing them.
This can be used as an alternative when a complete random read is too expensive
This can be used as an alternative when a complete random shuffle is too expensive
or impossible for the data source.
This dataflow has the following behavior:
1. It takes datapoints from the given dataflow `ds` to an internal buffer of fixed size.
Each datapoint is duplicated for `nr_reuse` times.
Each datapoint is duplicated for `num_reuse` times.
2. Once the buffer is full, this dataflow starts to yield data from the beginning of the buffer,
and new datapoints will be added to the end of the buffer. This is like a FIFO queue.
3. The internal buffer is shuffled after every `shuffle_interval` datapoints that come from `ds`.
......@@ -601,24 +602,28 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
because it does not make sense to stop the iteration anywhere.
"""
def __init__(self, ds, buffer_size, nr_reuse=1, shuffle_interval=None):
def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None, nr_reuse=None):
"""
Args:
ds (DataFlow): input DataFlow.
buffer_size (int): size of the buffer.
nr_reuse (int): duplicate each datapoints several times into the buffer to improve
num_reuse (int): duplicate each datapoints several times into the buffer to improve
speed, but duplication may hurt your model.
shuffle_interval (int): shuffle the buffer after this many
datapoints were produced from the given dataflow. Frequent shuffle on large buffer
may affect speed, but infrequent shuffle may not provide enough randomness.
Defaults to buffer_size / 3
nr_reuse: deprecated name for num_reuse
"""
if nr_reuse is not None:
log_deprecated("LocallyShuffleData(nr_reuse=...)", "Renamed to 'num_reuse'.", "2020-01-01")
num_reuse = nr_reuse
ProxyDataFlow.__init__(self, ds)
self.q = deque(maxlen=buffer_size)
if shuffle_interval is None:
shuffle_interval = int(buffer_size // 3)
self.shuffle_interval = shuffle_interval
self.nr_reuse = nr_reuse
self.num_reuse = num_reuse
self._inf_ds = RepeatedData(ds, -1)
def reset_state(self):
......@@ -629,7 +634,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self._inf_iter = iter(self._inf_ds)
def __len__(self):
return len(self.ds) * self.nr_reuse
return len(self.ds) * self.num_reuse
def __iter__(self):
with self._guard:
......@@ -638,7 +643,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
# fill queue
if self._iter_cnt == 0:
self.rng.shuffle(self.q)
for _ in range(self.nr_reuse):
for _ in range(self.num_reuse):
if self.q.maxlen == len(self.q):
yield self.q.popleft()
self.q.append(dp)
......@@ -646,7 +651,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
class CacheData(ProxyDataFlow):
"""
Cache the first pass of a DataFlow completely in memory,
Completely cache the first pass of a DataFlow in memory,
and produce from the cache thereafter.
NOTE: The user should not stop the iterator before it has reached the end.
......@@ -656,7 +661,7 @@ class CacheData(ProxyDataFlow):
"""
Args:
ds (DataFlow): input DataFlow.
shuffle (bool): whether to shuffle the datapoints before producing them.
shuffle (bool): whether to shuffle the cache before yielding from it.
"""
self.shuffle = shuffle
super(CacheData, self).__init__(ds)
......@@ -683,10 +688,11 @@ class CacheData(ProxyDataFlow):
class PrintData(ProxyDataFlow):
"""
Behave like an identity mapping, but print shape and range of the first few datapoints.
Behave like an identity proxy, but print shape and range of the first few datapoints.
Good for debugging.
Example:
To enable this debugging output, you should place it somewhere in your dataflow like
Place it somewhere in your dataflow like
.. code-block:: python
......
......@@ -166,7 +166,7 @@ class LMDBDataDecoder(MapData):
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
"""
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label].
Note that Caffe LMDB format is not efficient: it stores serialized raw
......@@ -175,9 +175,6 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
Args:
lmdb_path, shuffle, keys: same as :class:`LMDBData`.
Returns:
a :class:`LMDBDataDecoder` instance.
Example:
.. code-block:: python
......
......@@ -92,7 +92,7 @@ class AugmentImageComponent(MapDataComponent):
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented in the datapoint.
index (int or str): the index or key of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
......@@ -134,8 +134,8 @@ class AugmentImageCoordinates(MapData):
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int): the index of the image component to be augmented.
coords_index (int): the index of the coordinate component to be augmented.
img_index (int or str): the index/key of the image component to be augmented.
coords_index (int or str): the index/key of the coordinate component to be augmented.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
......
......@@ -14,12 +14,14 @@ import zmq
from six.moves import queue, range
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal)
from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow
__all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'MultiProcessRunner', 'MultiProcessRunnerZMQ', 'MultiThreadRunner',
'PrefetchDataZMQ', 'MultiThreadPrefetchData']
......@@ -35,7 +37,7 @@ def _bind_guard(sock, name):
except zmq.ZMQError:
logger.error(
"ZMQError in socket.bind('{}'). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ \
using pipes on a non-local file system. See documentation of MultiProcessRunnerZMQ \
for more information.".format(name))
raise
......@@ -118,27 +120,27 @@ class _MultiProcessZMQDataFlow(DataFlow):
pass
class MultiProcessPrefetchData(ProxyDataFlow):
class MultiProcessRunner(ProxyDataFlow):
"""
Prefetch data from a DataFlow using Python multiprocessing utilities.
It will fork the process calling :meth:`__init__`, collect datapoints from `ds` in each
Running a DataFlow in >=1 processes using Python multiprocessing utilities.
It will fork the process that calls :meth:`__init__`, collect datapoints from `ds` in each
process by a Python :class:`multiprocessing.Queue`.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that the process will be forked ``nr_proc`` times.
There will be ``nr_proc`` dataflow running in parallel and **independently**.
that the process will be forked ``num_proc`` times.
There will be ``num_proc`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``nr_proc=1``, this dataflow produces the same data as the
a. When ``num_proc=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``nr_proc>1``, if each sample from the given dataflow is i.i.d.,
b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``nr_proc`` times
then after parallel prefetching, the datapoint will be produced ``num_proc`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
......@@ -146,10 +148,11 @@ class MultiProcessPrefetchData(ProxyDataFlow):
To utilize parallelism with more strict data integrity, you can use
the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
2. This has more serialization overhead than :class:`PrefetchDataZMQ` when data is large.
3. You can nest like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
2. This has more serialization overhead than :class:`MultiProcessRunnerZMQ` when data is large.
3. You can nest like this: ``MultiProcessRunnerZMQ(MultiProcessRunner(df, num_proc=a), num_proc=b)``.
A total of ``a`` instances of ``df`` worker processes will be created.
4. fork happens in `__init__`. `reset_state()` is a no-op. The worker processes won't get called.
4. Fork happens in `__init__`. `reset_state()` is a no-op.
DataFlow in the worker processes will be reset at the time of fork.
5. This DataFlow does support windows. However, Windows requires more strict picklability on processes,
which means that some code that's forkable on Linux may not be forkable on Windows. If that happens you'll
need to re-organize some part of code that's not forkable.
......@@ -157,7 +160,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
class _Worker(mp.Process):
def __init__(self, ds, queue, idx):
super(MultiProcessPrefetchData._Worker, self).__init__()
super(MultiProcessRunner._Worker, self).__init__()
self.ds = ds
self.queue = queue
self.idx = idx
......@@ -170,33 +173,43 @@ class MultiProcessPrefetchData(ProxyDataFlow):
for dp in self.ds:
self.queue.put(dp)
def __init__(self, ds, nr_prefetch, nr_proc):
def __init__(self, ds, num_prefetch=None, num_proc=None, nr_prefetch=None, nr_proc=None):
"""
Args:
ds (DataFlow): input DataFlow.
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use.
num_prefetch (int): size of the queue to hold prefetched datapoints.
num_proc (int): number of processes to use.
nr_prefetch, nr_proc: deprecated argument names
"""
if nr_prefetch is not None:
log_deprecated("MultiProcessRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
num_prefetch = nr_prefetch
if nr_proc is not None:
log_deprecated("MultiProcessRunner(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
# https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#the-spawn-and-forkserver-start-methods
if os.name == 'nt':
logger.warn("MultiProcessPrefetchData does support Windows. \
logger.warn("MultiProcessRunner does support Windows. \
However, Windows requires more strict picklability on processes, which may \
lead of failure on some of the code.")
super(MultiProcessPrefetchData, self).__init__(ds)
super(MultiProcessRunner, self).__init__(ds)
try:
self._size = len(ds)
except NotImplementedError:
self._size = -1
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
assert num_proc > 0, num_proc
assert num_prefetch > 0, num_prefetch
self.num_proc = num_proc
self.num_prefetch = num_prefetch
if nr_proc > 1:
logger.info("[MultiProcessPrefetchData] Will fork a dataflow more than one times. "
if num_proc > 1:
logger.info("[MultiProcessRunner] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
self.queue = mp.Queue(self.nr_prefetch)
self.procs = [MultiProcessPrefetchData._Worker(self.ds, self.queue, idx)
for idx in range(self.nr_proc)]
self.queue = mp.Queue(self.num_prefetch)
self.procs = [MultiProcessRunner._Worker(self.ds, self.queue, idx)
for idx in range(self.num_proc)]
ensure_proc_terminate(self.procs)
start_proc_mask_signal(self.procs)
......@@ -212,31 +225,28 @@ lead of failure on some of the code.")
pass
PrefetchData = MultiProcessPrefetchData
# TODO renamed to MultiProcessDataFlow{,ZMQ} if separated to a new project
class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
"""
Prefetch data from a DataFlow using multiple processes, with ZeroMQ for communication.
Run a DataFlow in >=1 processes, with ZeroMQ for communication.
It will fork the calling process of :meth:`reset_state()`,
and collect datapoints from the given dataflow in each process by ZeroMQ IPC pipe.
This is typically faster than :class:`MultiProcessRunner`.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that the process will be forked ``nr_proc`` times.
There will be ``nr_proc`` dataflow running in parallel and **independently**.
that the process will be forked ``num_proc`` times.
There will be ``num_proc`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``nr_proc=1``, this dataflow produces the same data as the
a. When ``num_proc=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``nr_proc>1``, if each sample from the given dataflow is i.i.d.,
b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``nr_proc`` times
then after parallel prefetching, the datapoint will be produced ``num_proc`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
......@@ -251,10 +261,10 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
it's better to fork before creating the session.
4. (Fork-safety) After the fork has happened, this dataflow becomes not fork-safe.
i.e., if you fork an already reset instance of this dataflow,
it won't be usable in the forked process. Therefore, do not nest two `PrefetchDataZMQ`.
it won't be usable in the forked process. Therefore, do not nest two `MultiProcessRunnerZMQ`.
5. (Thread-safety) ZMQ is not thread safe. Therefore, do not call :meth:`get_data` of the same dataflow in
more than 1 threads.
6. This dataflow does not support windows. Use `MultiProcessPrefetchData` which works on windows.
6. This dataflow does not support windows. Use `MultiProcessRunner` which works on windows.
7. (For Mac only) A UNIX named pipe will be created in the current directory.
However, certain non-local filesystem such as NFS/GlusterFS/AFS doesn't always support pipes.
You can change the directory by ``export TENSORPACK_PIPEDIR=/other/dir``.
......@@ -269,7 +279,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
class _Worker(mp.Process):
def __init__(self, ds, conn_name, hwm, idx):
super(PrefetchDataZMQ._Worker, self).__init__()
super(MultiProcessRunnerZMQ._Worker, self).__init__()
self.ds = ds
self.conn_name = conn_name
self.hwm = hwm
......@@ -293,21 +303,25 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
socket.close(0)
context.destroy(0)
def __init__(self, ds, nr_proc=1, hwm=50):
def __init__(self, ds, num_proc=1, hwm=50, nr_proc=None):
"""
Args:
ds (DataFlow): input DataFlow.
nr_proc (int): number of processes to use.
num_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
nr_proc: deprecated
"""
super(PrefetchDataZMQ, self).__init__()
if nr_proc is not None:
log_deprecated("MultiProcessRunnerZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
super(MultiProcessRunnerZMQ, self).__init__()
self.ds = ds
self.nr_proc = nr_proc
self.num_proc = num_proc
self._hwm = hwm
if nr_proc > 1:
logger.info("[PrefetchDataZMQ] Will fork a dataflow more than one times. "
if num_proc > 1:
logger.info("[MultiProcessRunnerZMQ] Will fork a dataflow more than one times. "
"This assumes the datapoints are i.i.d.")
try:
self._size = ds.__len__()
......@@ -321,14 +335,14 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
return self.ds.__len__()
def __iter__(self):
with self._guard, _zmq_catch_error('PrefetchDataZMQ'):
with self._guard, _zmq_catch_error('MultiProcessRunnerZMQ'):
for k in itertools.count():
if self._size > 0 and k >= self._size:
break
yield self._recv()
def reset_state(self):
super(PrefetchDataZMQ, self).reset_state()
super(MultiProcessRunnerZMQ, self).reset_state()
self._guard = DataFlowReentrantGuard()
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
......@@ -336,32 +350,31 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow')
_bind_guard(self.socket, pipename)
self._procs = [PrefetchDataZMQ._Worker(self.ds, pipename, self._hwm, idx)
for idx in range(self.nr_proc)]
self._procs = [MultiProcessRunnerZMQ._Worker(self.ds, pipename, self._hwm, idx)
for idx in range(self.num_proc)]
self._start_processes()
# TODO renamed to MultiThreadDataFlow if separated to a new project
class MultiThreadPrefetchData(DataFlow):
class MultiThreadRunner(DataFlow):
"""
Create multiple dataflow instances and run them each in one thread.
Collect outputs with a queue.
Collect outputs from them with a queue.
Note:
1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
that each thread will create a dataflow iterator.
There will be ``nr_thread`` dataflow running in parallel and **independently**.
There will be ``num_thread`` dataflow running in parallel and **independently**.
As a result, we have the following guarantee on the dataflow correctness:
a. When ``nr_thread=1``, this dataflow produces the same data as the
a. When ``num_thread=1``, this dataflow produces the same data as the
given dataflow in the same order.
b. When ``nr_thread>1``, if each sample from the given dataflow is i.i.d.,
b. When ``num_thread>1``, if each sample from the given dataflow is i.i.d.,
then this dataflow produces the **same distribution** of data as the given dataflow.
This implies that there will be duplication, reordering, etc.
You probably only want to use it for training.
For example, if your original dataflow contains no randomness and produces the same first datapoint,
then after parallel prefetching, the datapoint will be produced ``nr_thread`` times
then after parallel prefetching, the datapoint will be produced ``num_thread`` times
at the beginning.
Even when your original dataflow is fully shuffled, you still need to be aware of the
`Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
......@@ -373,7 +386,7 @@ class MultiThreadPrefetchData(DataFlow):
class _Worker(StoppableThread):
def __init__(self, get_df, queue):
super(MultiThreadPrefetchData._Worker, self).__init__()
super(MultiThreadRunner._Worker, self).__init__()
self.df = get_df()
assert isinstance(self.df, DataFlow), self.df
self.queue = queue
......@@ -395,21 +408,29 @@ class MultiThreadPrefetchData(DataFlow):
finally:
self.stop()
def __init__(self, get_df, nr_prefetch, nr_thread):
def __init__(self, get_df, num_prefetch=None, num_thread=None, nr_prefetch=None, nr_thread=None):
"""
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow for each call.
nr_prefetch (int): size of the queue
nr_thread (int): number of threads
num_prefetch (int): size of the queue
num_thread (int): number of threads
nr_prefetch, nr_thread: deprecated names
"""
assert nr_thread > 0 and nr_prefetch > 0
self.nr_thread = nr_thread
self.queue = queue.Queue(maxsize=nr_prefetch)
if nr_prefetch is not None:
log_deprecated("MultiThreadRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
num_prefetch = nr_prefetch
if nr_thread is not None:
log_deprecated("MultiThreadRunner(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
assert num_thread > 0, num_thread
assert num_prefetch > 0, num_prefetch
self.num_thread = num_thread
self.queue = queue.Queue(maxsize=num_prefetch)
self.threads = [
MultiThreadPrefetchData._Worker(get_df, self.queue)
for _ in range(nr_thread)]
MultiThreadRunner._Worker(get_df, self.queue)
for _ in range(num_thread)]
def reset_state(self):
for th in self.threads:
......@@ -482,13 +503,19 @@ plasma = None
# PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow') # noqa
# The old inappropriate names:
PrefetchData = MultiProcessRunner
MultiProcessPrefetchData = MultiProcessRunner
PrefetchDataZMQ = MultiProcessRunnerZMQ
MultiThreadPrefetchData = MultiThreadRunner
if __name__ == '__main__':
import time
from .raw import DataFromGenerator
from .common import FixedSizeData
x = DataFromGenerator(itertools.count())
x = FixedSizeData(x, 100)
x = PrefetchDataZMQ(x, 2)
x = MultiProcessRunnerZMQ(x, 2)
x.reset_state()
for idx, dp in enumerate(x):
print(dp)
......
......@@ -10,11 +10,12 @@ from six.moves import queue
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps, loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
__all__ = ['ThreadedMapData', 'MultiThreadMapData',
__all__ = ['MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
......@@ -115,7 +116,7 @@ class MultiThreadMapData(_ParallelMapData):
1. You should avoid starting many threads in your main process to reduce GIL contention.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)``
Therefore you can use ``MultiProcessRunnerZMQ(MultiThreadMapData(...), 1)``
to reduce GIL contention.
"""
class _Worker(StoppableThread):
......@@ -143,16 +144,21 @@ class MultiThreadMapData(_ParallelMapData):
finally:
self.stop()
def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False):
def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
"""
Args:
ds (DataFlow): the dataflow to map
nr_thread (int): number of threads to use
num_thread (int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_thread: deprecated name
"""
if nr_thread is not None:
log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
......@@ -161,10 +167,10 @@ class MultiThreadMapData(_ParallelMapData):
pass
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread
assert num_thread > 0, num_thread
self._strict = strict
self.nr_thread = nr_thread
self.num_thread = num_thread
self.map_func = map_func
self._threads = []
self._evt = None
......@@ -181,7 +187,7 @@ class MultiThreadMapData(_ParallelMapData):
self._evt = threading.Event()
self._threads = [MultiThreadMapData._Worker(
self._in_queue, self._out_queue, self._evt, self.map_func)
for _ in range(self.nr_thread)]
for _ in range(self.num_thread)]
for t in self._threads:
t.start()
......@@ -211,10 +217,6 @@ class MultiThreadMapData(_ParallelMapData):
# logger.warn("Cannot join thread {}.".format(p.name))
# TODO deprecated
ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
"""
Same as :class:`MapData`, but start processes to run the mapping function,
......@@ -255,16 +257,20 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = self.map_func(dp)
socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False):
def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
"""
Args:
ds (DataFlow): the dataflow to map
nr_proc(int): number of threads to use
num_proc(int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_proc: deprecated name
"""
if nr_proc is not None:
log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
......@@ -274,8 +280,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
_ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc
self.nr_proc = nr_proc
assert num_proc > 0, num_proc
self.num_proc = num_proc
self.map_func = map_func
self._strict = strict
self._procs = []
......@@ -291,11 +297,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = int(self._buffer_size * 2 // self.num_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
for k in range(self.num_proc)]
self._start_processes()
self._fill_buffer() # pre-fill the bufer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment