Commit dd2d9ffa authored by Yuxin Wu's avatar Yuxin Wu

Rename PrefetchData -> MultiProcessRunner

parent 0cecfbb6
......@@ -372,7 +372,6 @@ _DEPRECATED_NAMES = set([
# deprecated stuff:
'QueueInputTrainer',
'dump_dataflow_to_process_queue',
'PrefetchOnGPUs',
'DistributedTrainerReplicated',
'DistributedTrainerParameterServer',
'InputDesc',
......@@ -382,11 +381,14 @@ _DEPRECATED_NAMES = set([
'DumpTensor',
'DumpParamAsImage',
'get_nr_gpu',
'start_test', # TestDataSpeed
'ThreadedMapData',
'TrainingMonitor',
'PeakMemoryTracker',
'PrefetchData',
'MultiProcessPrefetchData',
'PrefetchDataZMQ',
'MultiThreadPrefetchData',
# deprecated or renamed symbolic code
'Deconv2D', 'psnr',
......
......@@ -3,11 +3,11 @@
### What is DataFlow
DataFlow is a library to build Python iterators for efficient data loading.
DataFlow is a pure-Python library to create iterators for efficient data loading.
**Definition**: A DataFlow is a idiomatic Python container object that has a `__iter__()` generator method,
which yields `datapoints` and optionally a `__len__()` method returning the size of the flow.
A datapoint is a **list** of Python objects which are called the `components` of a datapoint.
**Definition**: A DataFlow is a idiomatic Python iterator object that has a `__iter__()` method
which yields `datapoints`, and optionally a `__len__()` method returning the size of the DataFlow.
A datapoint is a **list or dict** of Python objects, each of which are called the `components` of a datapoint.
**Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method
that yields datapoints (lists) of two components:
......@@ -21,12 +21,10 @@ You can simply use DataFlow as a data processing pipeline and plug it into any o
### Composition of DataFlow
One good thing about having a standard interface is to be able to provide
the greatest code reusability.
There are a lot of existing DataFlow utilities in tensorpack, which you can use to compose
DataFlow with complex data pipeline. A common pipeline usually
would __read from disk (or other sources), apply transformations, group into batches,
prefetch data__, etc. A simple example is as the following:
one DataFlow with complex data pipeline. A common pipeline usually
would __read from disk (or other sources), apply transformations (possibly in parallel), group into batches,
prefetch data__, etc, and all __run in parallel__. A simple example is as the following:
````python
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
......@@ -36,17 +34,17 @@ df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3)
df = MultiProcessRunnerZMQ(df, 3)
````
You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
with all the data preprocessing.
### Work with Your Data
Unless you are working with standard data types (image folders, LMDB, etc),
you would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your data format.
We do not make any assumptions about your data format.
You would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your own data format.
See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
Once you have the source reader, all the [existing
DataFlows](../modules/dataflow.html) are ready for you to build up the rest of the data pipeline.
Once you have the source reader, all the [built-in
DataFlows](../modules/dataflow.html) are ready for you to assemble the rest of the data pipeline.
### Why DataFlow
......@@ -62,16 +60,16 @@ Nevertheless, tensorpack supports data loading with native TF operators / TF dat
### Use DataFlow in Your Own Code
Normally, tensorpack `InputSource` interface runs the DataFlow during training.
However, DataFlow can also be used without other tensorpack components.
If you need to run the DataFlow by yourself, call `reset_state()` first to initialize it,
When training with tensorpack, typically it is the `InputSource` interface that runs the DataFlow.
However, DataFlow can be used without other tensorpack components.
To run a DataFlow by yourself, call `reset_state()` first to initialize it,
and then use the generator however you like:
```python
df = SomeDataFlow()
df.reset_state()
for dp in df:
# dp is now a list. do whatever
# dp is now a list. do whatever
```
Read the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlw)
......
......@@ -16,7 +16,7 @@ then apply complicated preprocessing to it.
We aim to reach a speed of, roughly **1k~3k images per second**, to keep GPUs busy.
Some things to know before reading:
1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess prefetch should usually work well enough.
1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess runner should usually work well enough.
Therefore you don't have to understand this tutorial in depth unless you really find your data being the bottleneck.
This tutorial could be a bit complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-scale dataset.
2. Having a fast Python generator **alone** may or may not improve your overall training speed.
......@@ -64,7 +64,7 @@ On a good filesystem you probably can already observe good speed here (e.g. 5 it
because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
Image decoding in `cv2.imread` could also be a bottleneck at this early stage.
### Parallel Prefetch
### Parallel Runner
We will now add the cheapest pre-processing now to get an ndarray in the end instead of a list
(because training will need ndarray eventually):
......@@ -84,15 +84,15 @@ Now it's time to add threads or processes:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
ds1 = AugmentImageComponent(ds0, lots_of_augmentors)
ds = PrefetchDataZMQ(ds1, nr_proc=25)
ds = MultiProcessRunnerZMQ(ds1, num_proc=25)
ds = BatchData(ds, 256)
```
Here we fork 25 processes to run `ds1`, and collect their output through ZMQ IPC protocol,
which is faster than `multiprocessing.Queue`. You can also apply prefetch after batch, of course.
which is faster than `multiprocessing.Queue`. You can also apply parallel runner after batching, of course.
### Parallel Map
The above DataFlow might be fast, but since it forks the ImageNet reader (`ds0`),
it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)).
it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)).
Alternatively, you can use multi-threaded preprocessing like this:
```eval_rst
......@@ -102,9 +102,9 @@ Alternatively, you can use multi-threaded preprocessing like this:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
ds0, nr_thread=25,
ds0, num_thread=25,
map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
# ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
......@@ -127,11 +127,11 @@ If you identify this as a bottleneck, you can also use:
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
ds0, nr_thread=25,
ds0, num_thread=25,
map_func=lambda dp:
[augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
buffer_size=1000)
ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
......@@ -159,15 +159,15 @@ class BinaryILSVRC12(dataset.ILSVRC12Files):
jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
yield [jpeg, label]
ds0 = BinaryILSVRC12('/path/to/ILSVRC/', 'train')
ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
ds1 = MultiProcessRunnerZMQ(ds0, num_proc=1)
LMDBSerializer.save(ds1, '/path/to/ILSVRC-train.lmdb')
```
The above script builds a DataFlow which produces jpeg-encoded ImageNet data.
We store the jpeg string as a numpy array because the function `cv2.imdecode` later expect this format.
Please note we can only use 1 prefetch process to speed up. If `nr_proc>1`, `ds1` will take data
Please note we can only use 1 runner process to speed up. If `num_proc>1`, `ds1` will take data
from several forks of `ds0`, then neither the content nor the order of `ds1` will be the same as `ds0`.
See [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)
about caveats of `PrefetchDataZMQ`.
See [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)
about caveats of `MultiProcessRunnerZMQ`.
It will generate a database file of 140G. We load the DataFlow back by reading this LMDB file sequentially:
```
......@@ -193,7 +193,7 @@ the added line above maintains a buffer of datapoints and shuffle them once a wh
It will not affect the model as long as the buffer is large enough,
but it can also consume much memory if too large.
### Augmentations & Parallel Prefetch
### Augmentations & Parallel Runner
Then we add necessary transformations:
```eval_rst
......@@ -218,24 +218,24 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
ds = LMDBSerializer.load(db, shuffle=False)
ds = LocallyShuffleData(ds, 50000)
ds = PrefetchData(ds, 5000, 1)
ds = MultiProcessRunner(ds, 5000, 1)
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = PrefetchDataZMQ(ds, 25)
ds = MultiProcessRunnerZMQ(ds, 25)
ds = BatchData(ds, 256)
```
Since we are reading the database sequentially, having multiple forked instances of the
base LMDB reader will result in biased data distribution. Therefore we use `PrefetchData` to
base LMDB reader will result in biased data distribution. Therefore we use `MultiProcessRunner` to
launch the base DataFlow in only **one process**, and only parallelize the transformations
with another `PrefetchDataZMQ`
(Nesting two `PrefetchDataZMQ`, however, will result in a different behavior.
with another `MultiProcessRunnerZMQ`
(Nesting two `MultiProcessRunnerZMQ`, however, will result in a different behavior.
These differences are explained in the API documentation in more details.).
Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.
Let me summarize what this DataFlow does:
1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `PrefetchData`).
1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `MultiProcessRunner`).
2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and
send them through ZMQ IPC pipe.
3. The main process takes data from the pipe, makes batches.
......
......@@ -82,7 +82,7 @@ def get_data(path, isTrain, stat_file):
ds = MapDataComponent(ds, lambda x: (x - mean) / std)
ds = TIMITBatch(ds, BATCH)
if isTrain:
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
return ds
......
......@@ -32,7 +32,7 @@ def get_data():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5)
data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -148,7 +148,7 @@ def get_config():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchDataZMQ(data_train, 5)
data_train = MultiProcessRunnerZMQ(data_train, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -225,7 +225,7 @@ def get_data():
ds = ThetaImages(ds)
ds = RepeatedData(ds, 50) # just pretend this dataset is bigger
# this pre-computation is pretty heavy
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = MultiProcessRunnerZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH)
return ds
......
......@@ -9,7 +9,7 @@ from tabulate import tabulate
from termcolor import colored
from tensorpack.dataflow import (
DataFromList, MapData, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData,
DataFromList, MapData, MapDataComponent, MultiProcessMapData, MultiThreadMapData,
TestDataSpeed, imgaug)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized
......@@ -392,7 +392,7 @@ def get_train_dataflow():
# MPI does not like fork()
else:
buffer_size = cfg.DATA.NUM_WORKERS * 20
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
else:
ds = MapData(ds, preprocess)
return ds
......
......@@ -177,7 +177,7 @@ def get_data(datadir, isTrain=True):
names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
df = BatchData(df, BATCH if isTrain else TEST_BATCH)
df = PrefetchDataZMQ(df, 2 if isTrain else 1)
df = MultiProcessRunnerZMQ(df, 2 if isTrain else 1)
return df
......
......@@ -115,7 +115,7 @@ def get_data():
ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, args.batch)
ds = PrefetchDataZMQ(ds, 5)
ds = MultiProcessRunnerZMQ(ds, 5)
return ds
......
......@@ -186,7 +186,7 @@ def get_celebA_data(datadir, styleA, styleB=None):
imgaug.Resize(64)]
df = AugmentImageComponents(df, augs, (0, 1))
df = BatchData(df, BATCH)
df = PrefetchDataZMQ(df, 3)
df = MultiProcessRunnerZMQ(df, 3)
return df
......
......@@ -173,7 +173,7 @@ def get_data():
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1)
ds = MultiProcessRunner(ds, 100, 1)
return ds
......
......@@ -233,7 +233,7 @@ def get_data(name):
]
ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
else:
ds = BatchData(ds, 1)
return ds
......
......@@ -11,7 +11,9 @@ import tensorflow as tf
import tqdm
from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
from tensorpack.dataflow import (
AugmentImageComponent, BatchData, MultiThreadMapData,
MultiProcessRunnerZMQ, dataset, imgaug)
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost, l2_regularizer
from tensorpack.predict import FeedfreePredictor, PredictConfig
......@@ -88,7 +90,7 @@ def get_imagenet_dataflow(
ds = AugmentImageComponent(ds, augmentors, copy=False)
if parallel < 16:
logger.warn("DataFlow may become the bottleneck when too few processes are used.")
ds = PrefetchDataZMQ(ds, parallel)
ds = MultiProcessRunnerZMQ(ds, parallel)
ds = BatchData(ds, batch_size, remainder=False)
else:
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
......@@ -101,7 +103,7 @@ def get_imagenet_dataflow(
return im, cls
ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1)
ds = MultiProcessRunnerZMQ(ds, 1)
return ds
......
......@@ -133,7 +133,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 3, 2)
ds = MultiProcessRunner(ds, 3, 2)
return ds
......
......@@ -68,7 +68,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain:
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = MultiProcessRunnerZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds
......
......@@ -254,7 +254,7 @@ def get_data(file_name):
imgaug.Flip(horiz=True)]
ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
ds = PrefetchDataZMQ(ds, 3)
ds = MultiProcessRunnerZMQ(ds, 3)
ds = BatchData(ds, BATCH_SIZE)
return ds
......
......@@ -103,7 +103,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchDataZMQ(ds, 5)
ds = MultiProcessRunnerZMQ(ds, 5)
return ds
......
......@@ -78,7 +78,7 @@ def get_data():
]
data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5)
data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors)
......
......@@ -37,7 +37,7 @@ def get_data(subset):
# something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
ds = PrefetchDataZMQ(ds, 2)
ds = MultiProcessRunnerZMQ(ds, 2)
ds = BatchData(ds, BATCH_SIZE)
return ds
......
......@@ -64,29 +64,25 @@ class DataFlow(object):
@abstractmethod
def __iter__(self):
"""
* A dataflow is an iterable. The :meth:`__iter__` method should yield a list each time.
Each element in the list should be either a number or a numpy array.
For now, tensorpack also **partially** supports dict instead of list.
* A dataflow is an iterable. The :meth:`__iter__` method should yield a list or dict each time.
Note that dict is **partially** supported at the moment: certain dataflow does not support dict.
* The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
(will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
again after the previous call returned.
again immediately after the previous call returned.
* For many dataflow, the :meth:`__iter__` method is non-reentrant, which means for an dataflow
instance ``df``, :meth:`df.__iter__` cannot be called before the previous
:meth:`df.__iter__` call has finished (iteration has stopped).
When it is non-reentrant, :meth:`df.__iter__` should throw an exception if
When a dataflow is non-reentrant, :meth:`df.__iter__` should throw an exception if
called before the previous call has finished.
For such non-reentrant dataflows, if you need to use the same dataflow in two places,
you need to create two dataflow instances.
Yields:
list: The datapoint, i.e. list of components.
list/dict: The datapoint, i.e. list/dict of components.
"""
def get_data(self):
return self.__iter__()
def __len__(self):
"""
* A dataflow can optionally implement :meth:`__len__`. If not implemented, it will
......@@ -95,7 +91,7 @@ class DataFlow(object):
* It returns an integer representing the size of the dataflow.
The return value **may not be accurate or meaningful** at all.
When saying the length is "accurate", it means that
:meth:`__iter__` will always yield this many of datapoints.
:meth:`__iter__` will always yield this many of datapoints before it stops iteration.
* There could be many reasons why :meth:`__len__` is inaccurate.
For example, some dataflow has dynamic size, if it throws away datapoints on the fly.
......@@ -103,8 +99,9 @@ class DataFlow(object):
the dataset, due to parallelism and buffering.
In this case it does not make sense to stop the iteration anywhere.
* Due to the above reasons, the length is only a rough guidance. Inside
tensorpack it's only used in these places:
* Due to the above reasons, the length is only a rough guidance.
And it's up to the user how to interpret it.
Inside tensorpack it's only used in these places:
+ A default ``steps_per_epoch`` in training, but you probably want to customize
it yourself, especially when using data-parallel trainer.
......@@ -121,9 +118,6 @@ class DataFlow(object):
"""
raise NotImplementedError()
def size(self):
return self.__len__()
def reset_state(self):
"""
* The caller must guarantee that :meth:`reset_state` should be called **once and only once**
......@@ -134,21 +128,28 @@ class DataFlow(object):
e.g., initialize random number generators (RNG), create worker processes.
Because it's very common to use RNG in data processing,
developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to an RNG.
developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to
a properly-initialized RNG.
* A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee).
A few number of dataflow is not fork-safe anytime, which will be mentioned in the docs.
There are a few other dataflows that are not fork-safe anytime, which will be mentioned in the docs.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either when you're using dataflow outside of tensorpack, or if you're writing a wrapper dataflow).
* Tensorpack's built-in forking dataflows (:class:`MultiProcessPrefetchData`, :class:`MultiProcessMapData`, etc)
* Tensorpack's built-in forking dataflows (:class:`MultiProcessRunner`, :class:`MultiProcessMapData`, etc)
and other component that uses dataflows (:class:`InputSource`)
already take care of the responsibility of calling this method.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either if you're using dtaflow outside of tensorpack,
or if you're writing a wrapper dataflow).
"""
pass
# These are the old (overly verbose) names for the methods:
def get_data(self):
return self.__iter__()
def size(self):
return self.__len__()
class RNGDataFlow(DataFlow):
""" A DataFlow with RNG"""
......@@ -156,7 +157,7 @@ class RNGDataFlow(DataFlow):
rng = None
"""
``self.rng`` is a ``np.random.RandomState`` instance that is initialized
correctly in ``RNGDataFlow.reset_state()``.
correctly (with different seeds in each process) in ``RNGDataFlow.reset_state()``.
"""
def reset_state(self):
......
This diff is collapsed.
......@@ -166,7 +166,7 @@ class LMDBDataDecoder(MapData):
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
"""
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label].
Note that Caffe LMDB format is not efficient: it stores serialized raw
......@@ -175,9 +175,6 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
Args:
lmdb_path, shuffle, keys: same as :class:`LMDBData`.
Returns:
a :class:`LMDBDataDecoder` instance.
Example:
.. code-block:: python
......
......@@ -92,7 +92,7 @@ class AugmentImageComponent(MapDataComponent):
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented in the datapoint.
index (int or str): the index or key of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
......@@ -134,8 +134,8 @@ class AugmentImageCoordinates(MapData):
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int): the index of the image component to be augmented.
coords_index (int): the index of the coordinate component to be augmented.
img_index (int or str): the index/key of the image component to be augmented.
coords_index (int or str): the index/key of the coordinate component to be augmented.
copy, catch_exceptions: same as in :class:`AugmentImageComponent`
"""
if isinstance(augmentors, AugmentorList):
......
This diff is collapsed.
......@@ -10,11 +10,12 @@ from six.moves import queue
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps, loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
__all__ = ['ThreadedMapData', 'MultiThreadMapData',
__all__ = ['MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ']
......@@ -115,7 +116,7 @@ class MultiThreadMapData(_ParallelMapData):
1. You should avoid starting many threads in your main process to reduce GIL contention.
The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)``
Therefore you can use ``MultiProcessRunnerZMQ(MultiThreadMapData(...), 1)``
to reduce GIL contention.
"""
class _Worker(StoppableThread):
......@@ -143,16 +144,21 @@ class MultiThreadMapData(_ParallelMapData):
finally:
self.stop()
def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False):
def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
"""
Args:
ds (DataFlow): the dataflow to map
nr_thread (int): number of threads to use
num_thread (int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_thread: deprecated name
"""
if nr_thread is not None:
log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
......@@ -161,10 +167,10 @@ class MultiThreadMapData(_ParallelMapData):
pass
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread
assert num_thread > 0, num_thread
self._strict = strict
self.nr_thread = nr_thread
self.num_thread = num_thread
self.map_func = map_func
self._threads = []
self._evt = None
......@@ -181,7 +187,7 @@ class MultiThreadMapData(_ParallelMapData):
self._evt = threading.Event()
self._threads = [MultiThreadMapData._Worker(
self._in_queue, self._out_queue, self._evt, self.map_func)
for _ in range(self.nr_thread)]
for _ in range(self.num_thread)]
for t in self._threads:
t.start()
......@@ -211,10 +217,6 @@ class MultiThreadMapData(_ParallelMapData):
# logger.warn("Cannot join thread {}.".format(p.name))
# TODO deprecated
ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
"""
Same as :class:`MapData`, but start processes to run the mapping function,
......@@ -255,16 +257,20 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = self.map_func(dp)
socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False):
def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
"""
Args:
ds (DataFlow): the dataflow to map
nr_proc(int): number of threads to use
num_proc(int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above.
nr_proc: deprecated name
"""
if nr_proc is not None:
log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints
try:
......@@ -274,8 +280,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
_ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc
self.nr_proc = nr_proc
assert num_proc > 0, num_proc
self.num_proc = num_proc
self.map_func = map_func
self._strict = strict
self._procs = []
......@@ -291,11 +297,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)]
worker_hwm = int(self._buffer_size * 2 // self.nr_proc)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = int(self._buffer_size * 2 // self.num_proc)
self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)]
for k in range(self.num_proc)]
self._start_processes()
self._fill_buffer() # pre-fill the bufer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment