Commit dd2d9ffa authored by Yuxin Wu

Rename PrefetchData -> MultiProcessRunner

parent 0cecfbb6
@@ -372,7 +372,6 @@ _DEPRECATED_NAMES = set([
     # deprecated stuff:
     'QueueInputTrainer',
     'dump_dataflow_to_process_queue',
-    'PrefetchOnGPUs',
     'DistributedTrainerReplicated',
     'DistributedTrainerParameterServer',
     'InputDesc',
@@ -382,11 +381,14 @@ _DEPRECATED_NAMES = set([
     'DumpTensor',
     'DumpParamAsImage',
     'get_nr_gpu',
+    'start_test',  # TestDataSpeed
+    'ThreadedMapData',
     'TrainingMonitor',
     'PeakMemoryTracker',
+    'PrefetchData',
+    'MultiProcessPrefetchData',
+    'PrefetchDataZMQ',
+    'MultiThreadPrefetchData',
     # deprecated or renamed symbolic code
     'Deconv2D', 'psnr',
......
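A side note on the rename pattern itself: the old names listed above stay importable as deprecated aliases, so existing user code keeps working while warning. A minimal standalone sketch of that pattern (the class body here is hypothetical; tensorpack's real shim uses its `log_deprecated` helper, which appears later in this diff):

```python
import warnings

class MultiProcessRunner:
    """The new name. Real implementation elided in this sketch."""
    def __init__(self, ds, num_prefetch, num_proc):
        self.ds, self.num_prefetch, self.num_proc = ds, num_prefetch, num_proc

def PrefetchData(*args, **kwargs):
    # Old name kept as a thin alias: warn, then forward to the new class.
    warnings.warn("PrefetchData was renamed to MultiProcessRunner",
                  DeprecationWarning, stacklevel=2)
    return MultiProcessRunner(*args, **kwargs)
```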
@@ -3,11 +3,11 @@
### What is DataFlow

-DataFlow is a library to build Python iterators for efficient data loading.
+DataFlow is a pure-Python library to create iterators for efficient data loading.

-**Definition**: A DataFlow is a idiomatic Python container object that has a `__iter__()` generator method,
-which yields `datapoints` and optionally a `__len__()` method returning the size of the flow.
+**Definition**: A DataFlow is an idiomatic Python iterator object that has a `__iter__()` method
+which yields `datapoints`, and optionally a `__len__()` method returning the size of the DataFlow.

-A datapoint is a **list** of Python objects which are called the `components` of a datapoint.
+A datapoint is a **list or dict** of Python objects, each of which is called a `component` of the datapoint.

**Example**: to train on the MNIST dataset, you may need a DataFlow with a `__iter__()` method
that yields datapoints (lists) of two components:
@@ -21,12 +21,10 @@ You can simply use DataFlow as a data processing pipeline and plug it into any other framework.
### Composition of DataFlow

-One good thing about having a standard interface is to be able to provide
-the greatest code reusability.
There are a lot of existing DataFlow utilities in tensorpack, which you can use to compose
-DataFlow with complex data pipeline. A common pipeline usually
-would __read from disk (or other sources), apply transformations, group into batches,
-prefetch data__, etc. A simple example is as the following:
+one DataFlow with complex data pipeline. A common pipeline usually
+would __read from disk (or other sources), apply transformations (possibly in parallel), group into batches,
+prefetch data__, etc., and all __run in parallel__. A simple example is as follows:

````python
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
@@ -36,17 +34,17 @@ df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128
df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel
-df = PrefetchDataZMQ(df, 3)
+df = MultiProcessRunnerZMQ(df, 3)
````
You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
with all the data preprocessing.

### Work with Your Data
-Unless you are working with standard data types (image folders, LMDB, etc),
-you would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your data format.
+We do not make any assumptions about your data format.
+You would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your own data format.
See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
-Once you have the source reader, all the [existing
-DataFlows](../modules/dataflow.html) are ready for you to build up the rest of the data pipeline.
+Once you have the source reader, all the [built-in
+DataFlows](../modules/dataflow.html) are ready for you to assemble the rest of the data pipeline.

### Why DataFlow
@@ -62,9 +60,9 @@ Nevertheless, tensorpack supports data loading with native TF operators / TF datasets
### Use DataFlow in Your Own Code
-Normally, tensorpack `InputSource` interface runs the DataFlow during training.
-However, DataFlow can also be used without other tensorpack components.
-If you need to run the DataFlow by yourself, call `reset_state()` first to initialize it,
+When training with tensorpack, typically it is the `InputSource` interface that runs the DataFlow.
+However, DataFlow can be used without other tensorpack components.
+To run a DataFlow by yourself, call `reset_state()` first to initialize it,
and then use the generator however you like:
```python
df = SomeDataFlow()
......
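For illustration, the standalone-usage pattern just described, using the small built-in `DataFromList` source (a hedged sketch, not part of this commit):

```python
from tensorpack.dataflow import DataFromList

df = DataFromList([[0], [1], [2]], shuffle=False)
df.reset_state()          # must be called once before iterating
for datapoint in df:      # each datapoint is a list of components, e.g. [0]
    print(datapoint)
```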
@@ -16,7 +16,7 @@ then apply complicated preprocessing to it.
We aim to reach a speed of roughly **1k~3k images per second**, to keep GPUs busy.

Some things to know before reading:
-1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess prefetch should usually work well enough.
+1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess runner should usually work well enough.
   Therefore you don't have to understand this tutorial in depth, unless you really find your data to be the bottleneck.
   This tutorial could be a bit complicated for people new to system architectures, but you do need this knowledge to run fast enough on an ImageNet-scale dataset.
2. Having a fast Python generator **alone** may or may not improve your overall training speed.
@@ -64,7 +64,7 @@ On a good filesystem you probably can already observe good speed here (e.g. 5 it/s)
because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
Image decoding in `cv2.imread` could also be a bottleneck at this early stage.

-### Parallel Prefetch
+### Parallel Runner

We will now add the cheapest pre-processing to get an ndarray in the end instead of a list
(because training will need ndarray eventually):
@@ -84,15 +84,15 @@ Now it's time to add threads or processes:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
ds1 = AugmentImageComponent(ds0, lots_of_augmentors)
-ds = PrefetchDataZMQ(ds1, nr_proc=25)
+ds = MultiProcessRunnerZMQ(ds1, num_proc=25)
ds = BatchData(ds, 256)
```
Here we fork 25 processes to run `ds1`, and collect their output through the ZMQ IPC protocol,
-which is faster than `multiprocessing.Queue`. You can also apply prefetch after batch, of course.
+which is faster than `multiprocessing.Queue`. You can also apply the parallel runner after batching, of course.

### Parallel Map
The above DataFlow might be fast, but since it forks the ImageNet reader (`ds0`),
-it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)).
+it's **not a good idea to use it for validation** (for reasons mentioned at the top; more details in the [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)).
Alternatively, you can use multi-threaded preprocessing like this:

```eval_rst
@@ -102,9 +102,9 @@ Alternatively, you can use multi-threaded preprocessing like this:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
-    ds0, nr_thread=25,
+    ds0, num_thread=25,
    map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
-# ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
+# ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
@@ -127,11 +127,11 @@ If you identify this as a bottleneck, you can also use:
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData(
-    ds0, nr_thread=25,
+    ds0, num_thread=25,
    map_func=lambda dp:
        [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
    buffer_size=1000)
-ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
+ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256)
```
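If your preprocessing is limited by the GIL rather than I/O, a process-based map is the drop-in alternative to the thread pool above. A hedged sketch using `MultiProcessMapData` (renamed from `MultiProcessMapDataZMQ` in this same release; the identity `map_func` is a placeholder for your decode + augment step):

```python
from tensorpack.dataflow import BatchData, MultiProcessMapData, dataset

def map_func(dp):
    # placeholder: plug in cv2.imread + augmentation here
    return dp

ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
# 25 worker processes apply map_func; outputs may be reordered across workers
ds1 = MultiProcessMapData(ds0, 25, map_func, buffer_size=1000)
ds = BatchData(ds1, 256)
```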
@@ -159,15 +159,15 @@ class BinaryILSVRC12(dataset.ILSVRC12Files):
        jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
        yield [jpeg, label]
ds0 = BinaryILSVRC12('/path/to/ILSVRC/', 'train')
-ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
+ds1 = MultiProcessRunnerZMQ(ds0, num_proc=1)
LMDBSerializer.save(ds1, '/path/to/ILSVRC-train.lmdb')
```
The above script builds a DataFlow which produces jpeg-encoded ImageNet data.
We store the jpeg string as a numpy array because the function `cv2.imdecode` later expects this format.
-Please note we can only use 1 prefetch process to speed up. If `nr_proc>1`, `ds1` will take data
+Please note we can only use 1 runner process to speed up. If `num_proc>1`, `ds1` will take data
from several forks of `ds0`, and then neither the content nor the order of `ds1` will be the same as `ds0`.
-See [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)
-about caveats of `PrefetchDataZMQ`.
+See [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)
+about caveats of `MultiProcessRunnerZMQ`.

It will generate a database file of 140G. We load the DataFlow back by reading this LMDB file sequentially:
```
@@ -193,7 +193,7 @@ the added line above maintains a buffer of datapoints and shuffle them once a while.
It will not affect the model as long as the buffer is large enough,
but it can also consume much memory if too large.

-### Augmentations & Parallel Prefetch
+### Augmentations & Parallel Runner

Then we add necessary transformations:
```eval_rst
@@ -218,24 +218,24 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like this:
ds = LMDBSerializer.load(db, shuffle=False)
ds = LocallyShuffleData(ds, 50000)
-ds = PrefetchData(ds, 5000, 1)
+ds = MultiProcessRunner(ds, 5000, 1)
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors)
-ds = PrefetchDataZMQ(ds, 25)
+ds = MultiProcessRunnerZMQ(ds, 25)
ds = BatchData(ds, 256)
```
Since we are reading the database sequentially, having multiple forked instances of the
-base LMDB reader will result in biased data distribution. Therefore we use `PrefetchData` to
+base LMDB reader will result in a biased data distribution. Therefore we use `MultiProcessRunner` to
launch the base DataFlow in only **one process**, and only parallelize the transformations
-with another `PrefetchDataZMQ`
+with another `MultiProcessRunnerZMQ`
-(Nesting two `PrefetchDataZMQ`, however, will result in a different behavior.
+(Nesting two `MultiProcessRunnerZMQ`, however, will result in a different behavior.
These differences are explained in the API documentation in more detail.).
Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.

Let me summarize what this DataFlow does:
-1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `PrefetchData`).
+1. One process reads the LMDB file, shuffles datapoints in a buffer, and puts them into a `multiprocessing.Queue` (used by `MultiProcessRunner`).
2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and
   send them through a ZMQ IPC pipe.
3. The main process takes data from the pipe and makes batches.
......
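Taken together, the three summarized steps correspond to this pipeline, assembled from the snippets in this tutorial under the new names (the path and the augmentor list are placeholders):

```python
import cv2
from tensorpack.dataflow import (
    AugmentImageComponent, BatchData, LMDBSerializer, LocallyShuffleData,
    MapDataComponent, MultiProcessRunner, MultiProcessRunnerZMQ)

db = '/path/to/ILSVRC-train.lmdb'
lots_of_augmentors = []   # placeholder: your imgaug list goes here

ds = LMDBSerializer.load(db, shuffle=False)
ds = LocallyShuffleData(ds, 50000)          # step 1: buffer + shuffle
ds = MultiProcessRunner(ds, 5000, 1)        # step 1: a single reader process
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = MultiProcessRunnerZMQ(ds, 25)          # step 2: 25 decode/augment workers
ds = BatchData(ds, 256)                     # step 3: batching in the main process
```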
@@ -82,7 +82,7 @@ def get_data(path, isTrain, stat_file):
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
-        ds = PrefetchDataZMQ(ds, 1)
+        ds = MultiProcessRunnerZMQ(ds, 1)
    return ds
......
@@ -32,7 +32,7 @@ def get_data():
    ]
    data_train = AugmentImageComponent(data_train, augmentors)
    data_train = BatchData(data_train, 128)
-    data_train = PrefetchData(data_train, 5, 5)
+    data_train = MultiProcessRunner(data_train, 5, 5)
    augmentors = [imgaug.Resize((40, 40))]
    data_test = AugmentImageComponent(data_test, augmentors)
......
@@ -148,7 +148,7 @@ def get_config():
    ]
    data_train = AugmentImageComponent(data_train, augmentors)
    data_train = BatchData(data_train, 128)
-    data_train = PrefetchDataZMQ(data_train, 5)
+    data_train = MultiProcessRunnerZMQ(data_train, 5)
    augmentors = [imgaug.Resize((40, 40))]
    data_test = AugmentImageComponent(data_test, augmentors)
......
@@ -225,7 +225,7 @@ def get_data():
    ds = ThetaImages(ds)
    ds = RepeatedData(ds, 50)  # just pretend this dataset is bigger
    # this pre-computation is pretty heavy
-    ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
+    ds = MultiProcessRunnerZMQ(ds, min(20, multiprocessing.cpu_count()))
    ds = BatchData(ds, BATCH)
    return ds
......
@@ -9,7 +9,7 @@ from tabulate import tabulate
from termcolor import colored

from tensorpack.dataflow import (
-    DataFromList, MapData, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData,
+    DataFromList, MapData, MapDataComponent, MultiProcessMapData, MultiThreadMapData,
    TestDataSpeed, imgaug)
from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized
@@ -392,7 +392,7 @@ def get_train_dataflow():
                # MPI does not like fork()
            else:
                buffer_size = cfg.DATA.NUM_WORKERS * 20
-                ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
+                ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
        else:
            ds = MapData(ds, preprocess)
    return ds
......
@@ -177,7 +177,7 @@ def get_data(datadir, isTrain=True):
    names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
    df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
    df = BatchData(df, BATCH if isTrain else TEST_BATCH)
-    df = PrefetchDataZMQ(df, 2 if isTrain else 1)
+    df = MultiProcessRunnerZMQ(df, 2 if isTrain else 1)
    return df
......
@@ -115,7 +115,7 @@ def get_data():
    ds = ImageFromFile(imgs, channel=3, shuffle=True)
    ds = AugmentImageComponent(ds, get_augmentors())
    ds = BatchData(ds, args.batch)
-    ds = PrefetchDataZMQ(ds, 5)
+    ds = MultiProcessRunnerZMQ(ds, 5)
    return ds
......
@@ -186,7 +186,7 @@ def get_celebA_data(datadir, styleA, styleB=None):
            imgaug.Resize(64)]
    df = AugmentImageComponents(df, augs, (0, 1))
    df = BatchData(df, BATCH)
-    df = PrefetchDataZMQ(df, 3)
+    df = MultiProcessRunnerZMQ(df, 3)
    return df
......
@@ -173,7 +173,7 @@ def get_data():
    augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
    ds = AugmentImageComponents(ds, augs, (0, 1))
    ds = BatchData(ds, BATCH)
-    ds = PrefetchData(ds, 100, 1)
+    ds = MultiProcessRunner(ds, 100, 1)
    return ds
......
@@ -233,7 +233,7 @@ def get_data(name):
        ]
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        ds = BatchDataByShape(ds, 8, idx=0)
-        ds = PrefetchDataZMQ(ds, 1)
+        ds = MultiProcessRunnerZMQ(ds, 1)
    else:
        ds = BatchData(ds, 1)
    return ds
......
@@ -11,7 +11,9 @@ import tensorflow as tf
import tqdm

from tensorpack import ModelDesc
-from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
+from tensorpack.dataflow import (
+    AugmentImageComponent, BatchData, MultiThreadMapData,
+    MultiProcessRunnerZMQ, dataset, imgaug)
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost, l2_regularizer
from tensorpack.predict import FeedfreePredictor, PredictConfig
@@ -88,7 +90,7 @@ def get_imagenet_dataflow(
        ds = AugmentImageComponent(ds, augmentors, copy=False)
        if parallel < 16:
            logger.warn("DataFlow may become the bottleneck when too few processes are used.")
-        ds = PrefetchDataZMQ(ds, parallel)
+        ds = MultiProcessRunnerZMQ(ds, parallel)
        ds = BatchData(ds, batch_size, remainder=False)
    else:
        ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
@@ -101,7 +103,7 @@ def get_imagenet_dataflow(
            return im, cls
        ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
        ds = BatchData(ds, batch_size, remainder=True)
-        ds = PrefetchDataZMQ(ds, 1)
+        ds = MultiProcessRunnerZMQ(ds, 1)
    return ds
......
@@ -133,7 +133,7 @@ def get_data(train_or_test):
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
    if isTrain:
-        ds = PrefetchData(ds, 3, 2)
+        ds = MultiProcessRunner(ds, 3, 2)
    return ds
......
@@ -68,7 +68,7 @@ def get_data(train_or_test):
    ds = AugmentImageComponent(ds, augmentors, copy=False)
    if isTrain:
-        ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count()))
+        ds = MultiProcessRunnerZMQ(ds, min(25, multiprocessing.cpu_count()))
    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
    return ds
......
@@ -254,7 +254,7 @@ def get_data(file_name):
                  imgaug.Flip(horiz=True)]
    ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
    ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
-    ds = PrefetchDataZMQ(ds, 3)
+    ds = MultiProcessRunnerZMQ(ds, 3)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
......
@@ -103,7 +103,7 @@ def get_data(train_or_test, cifar_classnum):
    ds = AugmentImageComponent(ds, augmentors)
    ds = BatchData(ds, 128, remainder=not isTrain)
    if isTrain:
-        ds = PrefetchDataZMQ(ds, 5)
+        ds = MultiProcessRunnerZMQ(ds, 5)
    return ds
......
@@ -78,7 +78,7 @@ def get_data():
    ]
    data_train = AugmentImageComponent(data_train, augmentors)
    data_train = BatchData(data_train, 128)
-    data_train = PrefetchData(data_train, 5, 5)
+    data_train = MultiProcessRunner(data_train, 5, 5)
    augmentors = [imgaug.Resize((40, 40))]
    data_test = AugmentImageComponent(data_test, augmentors)
......
@@ -37,7 +37,7 @@ def get_data(subset):
    # something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
    ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
                  dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
-    ds = PrefetchDataZMQ(ds, 2)
+    ds = MultiProcessRunnerZMQ(ds, 2)
    ds = BatchData(ds, BATCH_SIZE)
    return ds
......
@@ -64,29 +64,25 @@ class DataFlow(object):
    @abstractmethod
    def __iter__(self):
        """
-        * A dataflow is an iterable. The :meth:`__iter__` method should yield a list each time.
-          Each element in the list should be either a number or a numpy array.
-          For now, tensorpack also **partially** supports dict instead of list.
+        * A dataflow is an iterable. The :meth:`__iter__` method should yield a list or dict each time.
+          Note that dict is **partially** supported at the moment: certain dataflows do not support dict.

        * The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
          (will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
-          again after the previous call returned.
+          again immediately after the previous call returned.

        * For many dataflows, the :meth:`__iter__` method is non-reentrant, which means for a dataflow
          instance ``df``, :meth:`df.__iter__` cannot be called before the previous
          :meth:`df.__iter__` call has finished (iteration has stopped).
-          When it is non-reentrant, :meth:`df.__iter__` should throw an exception if
+          When a dataflow is non-reentrant, :meth:`df.__iter__` should throw an exception if
          called before the previous call has finished.
          For such non-reentrant dataflows, if you need to use the same dataflow in two places,
          you need to create two dataflow instances.

        Yields:
-            list: The datapoint, i.e. list of components.
+            list/dict: The datapoint, i.e. list/dict of components.
        """

-    def get_data(self):
-        return self.__iter__()

    def __len__(self):
        """
        * A dataflow can optionally implement :meth:`__len__`. If not implemented, it will
@@ -95,7 +91,7 @@ class DataFlow(object):
        * It returns an integer representing the size of the dataflow.
          The return value **may not be accurate or meaningful** at all.
          When saying the length is "accurate", it means that
-          :meth:`__iter__` will always yield this many of datapoints.
+          :meth:`__iter__` will always yield this many datapoints before it stops iteration.

        * There could be many reasons why :meth:`__len__` is inaccurate.
          For example, some dataflow has dynamic size, if it throws away datapoints on the fly.
@@ -103,8 +99,9 @@ class DataFlow(object):
          the dataset, due to parallelism and buffering.
          In this case it does not make sense to stop the iteration anywhere.

-        * Due to the above reasons, the length is only a rough guidance. Inside
-          tensorpack it's only used in these places:
+        * Due to the above reasons, the length is only a rough guidance,
+          and it's up to the user how to interpret it.
+          Inside tensorpack it's only used in these places:

          + A default ``steps_per_epoch`` in training, but you probably want to customize
            it yourself, especially when using data-parallel trainer.
@@ -121,9 +118,6 @@ class DataFlow(object):
        """
        raise NotImplementedError()

-    def size(self):
-        return self.__len__()

    def reset_state(self):
        """
        * The caller must guarantee that :meth:`reset_state` should be called **once and only once**
@@ -134,21 +128,28 @@ class DataFlow(object):
          e.g., initialize random number generators (RNG), create worker processes.

          Because it's very common to use RNG in data processing,
-          developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to an RNG.
+          developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to
+          a properly-initialized RNG.

        * A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee).
-          A few number of dataflow is not fork-safe anytime, which will be mentioned in the docs.
+          There are a few other dataflows that are not fork-safe at any time, which will be mentioned in the docs.

+        * You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
+          (either when you're using dataflow outside of tensorpack, or if you're writing a wrapper dataflow).

        * Tensorpack's built-in forking dataflows (:class:`MultiProcessRunner`, :class:`MultiProcessMapData`, etc)
          and other components that use dataflows (:class:`InputSource`)
          already take care of the responsibility of calling this method.
-
-        * You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
-          (either if you're using dtaflow outside of tensorpack,
-          or if you're writing a wrapper dataflow).
        """
        pass

+    # These are the old (overly verbose) names for the methods:
+    def get_data(self):
+        return self.__iter__()
+
+    def size(self):
+        return self.__len__()


class RNGDataFlow(DataFlow):
    """ A DataFlow with RNG"""
@@ -156,7 +157,7 @@ class RNGDataFlow(DataFlow):
    rng = None
    """
    ``self.rng`` is a ``np.random.RandomState`` instance that is initialized
-    correctly in ``RNGDataFlow.reset_state()``.
+    correctly (with different seeds in each process) in ``RNGDataFlow.reset_state()``.
    """

    def reset_state(self):
......
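To make the contract documented above concrete, here is a hedged sketch of a minimal DataFlow that follows the protocol (a toy source written for this note, not part of the commit):

```python
from tensorpack.dataflow import RNGDataFlow

class RandomImages(RNGDataFlow):
    """Toy source: yields [image, label] datapoints."""
    def __init__(self, size):
        self._size = size

    def __len__(self):
        return self._size   # "accurate": __iter__ yields exactly this many

    def __iter__(self):
        for _ in range(self._size):
            img = self.rng.uniform(size=(28, 28))  # self.rng comes from reset_state()
            label = self.rng.randint(10)
            yield [img, label]

df = RandomImages(100)
df.reset_state()            # once and only once, before iterating
for img, label in df:
    pass
```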
@@ -14,6 +14,7 @@ from termcolor import colored
from ..utils import logger
from ..utils.utils import get_rng, get_tqdm, get_tqdm_kwargs
+from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow, RNGDataFlow

__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
@@ -23,7 +24,7 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
class TestDataSpeed(ProxyDataFlow):
-    """ Test the speed of some DataFlow """
+    """ Test the speed of a DataFlow """
    def __init__(self, ds, size=5000, warmup=0):
        """
        Args:
@@ -175,7 +176,7 @@ class BatchDataByShape(BatchData):
    Note:
        It is implemented by a dict{shape -> datapoints}.
-        Datapoints of uncommon shapes may never be enough to form a batch and
+        Therefore, datapoints of uncommon shapes may never be enough to form a batch and
        never get generated.
    """
    def __init__(self, ds, batch_size, idx):
@@ -184,7 +185,7 @@ class BatchDataByShape(BatchData):
            ds (DataFlow): input DataFlow. ``dp[idx]`` has to be an :class:`np.ndarray`.
            batch_size (int): batch size
            idx (int): ``dp[idx].shape`` will be used to group datapoints.
-                Other components are assumed to have the same shape.
+                Other components are assumed to be batch-able.
        """
        super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
        self.idx = idx
@@ -267,13 +268,13 @@ class MapData(ProxyDataFlow):
    Note:
        1. Please make sure func doesn't modify its arguments in place,
           unless you're certain it's safe.
-        2. If you discard some datapoints, ``len(ds)`` will be incorrect.
+        2. If you discard some datapoints, ``len(MapData(ds))`` will be incorrect.

    Example:

        .. code-block:: none

-            ds = Mnist('train)
+            ds = Mnist('train')  # each datapoint is [img, label]
            ds = MapData(ds, lambda dp: [dp[0] * 255, dp[1]])
    """
@@ -302,14 +303,14 @@ class MapDataComponent(MapData):
        1. This dataflow itself doesn't modify the datapoints.
           But please make sure func doesn't modify its arguments in place,
           unless you're certain it's safe.
-        2. If you discard some datapoints, ``len(ds)`` will be incorrect.
+        2. If you discard some datapoints, ``len(MapDataComponent(ds, ..))`` will be incorrect.

    Example:

        .. code-block:: none

-            ds = Mnist('train)
-            ds = MapDataComponent(ds, lambda img: img * 255, 0)
+            ds = Mnist('train')  # each datapoint is [img, label]
+            ds = MapDataComponent(ds, lambda img: img * 255, 0)  # map the 0th component
    """
    def __init__(self, ds, func, index=0):
        """
@@ -340,32 +341,32 @@ class RepeatedData(ProxyDataFlow):
        dp1, dp2, .... dpn, dp1, dp2, ....dpn
    """

-    def __init__(self, ds, nr):
+    def __init__(self, ds, num):
        """
        Args:
            ds (DataFlow): input DataFlow
-            nr (int): number of times to repeat ds.
+            num (int): number of times to repeat ds.
                Set to -1 to repeat ``ds`` infinite times.
        """
-        self.nr = nr
+        self.num = num
        super(RepeatedData, self).__init__(ds)

    def __len__(self):
        """
        Raises:
-            :class:`ValueError` when nr == -1.
+            :class:`NotImplementedError` when num == -1.
        """
-        if self.nr == -1:
+        if self.num == -1:
            raise NotImplementedError("__len__() is unavailable for infinite dataflow")
-        return len(self.ds) * self.nr
+        return len(self.ds) * self.num

    def __iter__(self):
-        if self.nr == -1:
+        if self.num == -1:
            while True:
                for dp in self.ds:
                    yield dp
        else:
-            for _ in range(self.nr):
+            for _ in range(self.num):
                for dp in self.ds:
                    yield dp
@@ -376,22 +377,22 @@ class RepeatedDataPoint(ProxyDataFlow):
        dp1, dp1, ..., dp1, dp2, ..., dp2, ...
    """

-    def __init__(self, ds, nr):
+    def __init__(self, ds, num):
        """
        Args:
            ds (DataFlow): input DataFlow
-            nr (int): number of times to repeat each datapoint.
+            num (int): number of times to repeat each datapoint.
        """
-        self.nr = int(nr)
-        assert self.nr >= 1, self.nr
+        self.num = int(num)
+        assert self.num >= 1, self.num
        super(RepeatedDataPoint, self).__init__(ds)

    def __len__(self):
-        return len(self.ds) * self.nr
+        return len(self.ds) * self.num

    def __iter__(self):
        for dp in self.ds:
-            for _ in range(self.nr):
+            for _ in range(self.num):
                yield dp
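For illustration, a hedged sketch of how the renamed argument reads at call sites (`DataFromList` is just a convenient source):

```python
from tensorpack.dataflow import DataFromList, RepeatedData, RepeatedDataPoint

df = DataFromList([[1], [2], [3]], shuffle=False)
df = RepeatedDataPoint(df, num=2)   # yields 1 1 2 2 3 3
df = RepeatedData(df, num=-1)       # repeat forever; len(df) now raises
```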
@@ -474,7 +475,7 @@ class RandomMixData(RNGDataFlow):
class ConcatData(DataFlow):
    """
    Concatenate several DataFlow.
-    Produce datapoints from each DataFlow and go to the next when one
+    Produce datapoints from each DataFlow and start the next when one
    DataFlow is exhausted.
    """
@@ -501,8 +502,8 @@ class ConcatData(DataFlow):
class JoinData(DataFlow):
    """
    Join the components from each DataFlow. See below for its behavior.
-    Dataflow that produces lists and dataflow that produces dicts
-    cannot be joined.
+    Note that you can't join a DataFlow that produces lists with one that produces dicts.

    Example:
@@ -524,7 +525,7 @@ class JoinData(DataFlow):
        When these dataflows have different sizes, JoinData will stop when any
        of them is exhausted.
        The list could contain the same DataFlow instance more than once,
-        but note that `__iter__` will then also be called many times.
+        but note that in that case `__iter__` will also be called many times.
        """
        self.df_lists = df_lists
@@ -568,7 +569,7 @@ def SelectComponent(ds, idxs):
    Args:
        ds (DataFlow): input DataFlow.
-        idxs (list[int]): a list of component indices.
+        idxs (list[int] or list[str]): a list of component indices/keys.

    Example:
@@ -583,13 +584,13 @@ def SelectComponent(ds, idxs):
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
    """ Buffer the datapoints from a given dataflow, and shuffle them before producing them.
-    This can be used as an alternative when a complete random read is too expensive
+    This can be used as an alternative when a complete random shuffle is too expensive
    or impossible for the data source.

    This dataflow has the following behavior:

    1. It takes datapoints from the given dataflow `ds` to an internal buffer of fixed size.
-       Each datapoint is duplicated for `nr_reuse` times.
+       Each datapoint is duplicated `num_reuse` times.
    2. Once the buffer is full, this dataflow starts to yield data from the beginning of the buffer,
       and new datapoints will be added to the end of the buffer. This is like a FIFO queue.
    3. The internal buffer is shuffled after every `shuffle_interval` datapoints that come from `ds`.
@@ -601,24 +602,28 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
       because it does not make sense to stop the iteration anywhere.
    """

-    def __init__(self, ds, buffer_size, nr_reuse=1, shuffle_interval=None):
+    def __init__(self, ds, buffer_size, num_reuse=1, shuffle_interval=None, nr_reuse=None):
        """
        Args:
            ds (DataFlow): input DataFlow.
            buffer_size (int): size of the buffer.
-            nr_reuse (int): duplicate each datapoints several times into the buffer to improve
+            num_reuse (int): duplicate each datapoint several times into the buffer to improve
                speed, but duplication may hurt your model.
            shuffle_interval (int): shuffle the buffer after this many
                datapoints were produced from the given dataflow. Frequent shuffle on large buffer
                may affect speed, but infrequent shuffle may not provide enough randomness.
                Defaults to buffer_size / 3
+            nr_reuse: deprecated name for num_reuse
        """
+        if nr_reuse is not None:
+            log_deprecated("LocallyShuffleData(nr_reuse=...)", "Renamed to 'num_reuse'.", "2020-01-01")
+            num_reuse = nr_reuse
        ProxyDataFlow.__init__(self, ds)
        self.q = deque(maxlen=buffer_size)
        if shuffle_interval is None:
            shuffle_interval = int(buffer_size // 3)
        self.shuffle_interval = shuffle_interval
-        self.nr_reuse = nr_reuse
+        self.num_reuse = num_reuse
        self._inf_ds = RepeatedData(ds, -1)

    def reset_state(self):
@@ -629,7 +634,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
        self._inf_iter = iter(self._inf_ds)

    def __len__(self):
-        return len(self.ds) * self.nr_reuse
+        return len(self.ds) * self.num_reuse

    def __iter__(self):
        with self._guard:
@@ -638,7 +643,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
                # fill queue
                if self._iter_cnt == 0:
                    self.rng.shuffle(self.q)
-                for _ in range(self.nr_reuse):
+                for _ in range(self.num_reuse):
                    if self.q.maxlen == len(self.q):
                        yield self.q.popleft()
                    self.q.append(dp)
@@ -646,7 +651,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
class CacheData(ProxyDataFlow):
    """
-    Cache the first pass of a DataFlow completely in memory,
+    Completely cache the first pass of a DataFlow in memory,
    and produce from the cache thereafter.

    NOTE: The user should not stop the iterator before it has reached the end.
@@ -656,7 +661,7 @@ class CacheData(ProxyDataFlow):
        """
        Args:
            ds (DataFlow): input DataFlow.
-            shuffle (bool): whether to shuffle the datapoints before producing them.
+            shuffle (bool): whether to shuffle the cache before yielding from it.
        """
        self.shuffle = shuffle
        super(CacheData, self).__init__(ds)
@@ -683,10 +688,11 @@ class PrintData(ProxyDataFlow):
class PrintData(ProxyDataFlow):
    """
-    Behave like an identity mapping, but print shape and range of the first few datapoints.
+    Behave like an identity proxy, but print shape and range of the first few datapoints.
+    Good for debugging.

    Example:
-        To enable this debugging output, you should place it somewhere in your dataflow like
+        Place it somewhere in your dataflow like

        .. code-block:: python
......
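For illustration, a hedged sketch of the placement the truncated docstring example refers to (the toy source is made up for this note):

```python
import numpy as np
from tensorpack.dataflow import BatchData, DataFromList, PrintData

df = DataFromList([[np.zeros((28, 28))], [np.ones((28, 28))]], shuffle=False)
df = PrintData(df)       # identity proxy; logs shape/range of the first datapoints
df = BatchData(df, 2)    # the rest of the pipeline is unchanged
```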
@@ -166,7 +166,7 @@ class LMDBDataDecoder(MapData):
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    """
-    Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
+    Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf.
    Produces datapoints of the format: [HWC image, label].

    Note that Caffe LMDB format is not efficient: it stores serialized raw
@@ -175,9 +175,6 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    Args:
        lmdb_path, shuffle, keys: same as :class:`LMDBData`.

-    Returns:
-        a :class:`LMDBDataDecoder` instance.

    Example:

        .. code-block:: python
......
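For illustration, a hedged minimal call of this reader (the path is a placeholder, and decoding requires the `caffe.Datum` protobuf to be importable at runtime):

```python
from tensorpack.dataflow import CaffeLMDB

# decodes each value as a caffe.Datum and yields [HWC image, label]
ds = CaffeLMDB('/path/to/caffe.lmdb', shuffle=False)
```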
@@ -92,7 +92,7 @@ class AugmentImageComponent(MapDataComponent):
    Args:
        ds (DataFlow): input DataFlow.
        augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
-        index (int): the index of the image component to be augmented in the datapoint.
+        index (int or str): the index or key of the image component to be augmented in the datapoint.
        copy (bool): Some augmentors modify the input images. When copy is
            True, a copy will be made before any augmentors are applied,
            to keep the original images not modified.
@@ -134,8 +134,8 @@ class AugmentImageCoordinates(MapData):
    Args:
        ds (DataFlow): input DataFlow.
        augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
-        img_index (int): the index of the image component to be augmented.
-        coords_index (int): the index of the coordinate component to be augmented.
+        img_index (int or str): the index/key of the image component to be augmented.
+        coords_index (int or str): the index/key of the coordinate component to be augmented.
        copy, catch_exceptions: same as in :class:`AugmentImageComponent`
    """
    if isinstance(augmentors, AugmentorList):
......
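Since `index` can now be a key, augmenting dict-style datapoints looks roughly like this (a hedged sketch; `MyDictDataFlow` and the `'image'` key are hypothetical):

```python
from tensorpack.dataflow import AugmentImageComponent, imgaug

# datapoints here are dicts such as {'image': img, 'label': lbl}
ds = MyDictDataFlow()   # placeholder source producing dict datapoints
ds = AugmentImageComponent(ds, [imgaug.Flip(horiz=True)], index='image')
```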
...@@ -14,12 +14,14 @@ import zmq ...@@ -14,12 +14,14 @@ import zmq
from six.moves import queue, range from six.moves import queue, range
from ..utils import logger from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.concurrency import ( from ..utils.concurrency import (
StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal) StoppableThread, enable_death_signal, ensure_proc_terminate, start_proc_mask_signal)
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow from .base import DataFlow, DataFlowReentrantGuard, DataFlowTerminated, ProxyDataFlow
__all__ = ['PrefetchData', 'MultiProcessPrefetchData', __all__ = ['PrefetchData', 'MultiProcessPrefetchData',
'MultiProcessRunner', 'MultiProcessRunnerZMQ', 'MultiThreadRunner',
'PrefetchDataZMQ', 'MultiThreadPrefetchData'] 'PrefetchDataZMQ', 'MultiThreadPrefetchData']
...@@ -35,7 +37,7 @@ def _bind_guard(sock, name): ...@@ -35,7 +37,7 @@ def _bind_guard(sock, name):
except zmq.ZMQError: except zmq.ZMQError:
logger.error( logger.error(
"ZMQError in socket.bind('{}'). Perhaps you're \ "ZMQError in socket.bind('{}'). Perhaps you're \
using pipes on a non-local file system. See documentation of PrefetchDataZMQ \ using pipes on a non-local file system. See documentation of MultiProcessRunnerZMQ \
for more information.".format(name)) for more information.".format(name))
raise raise
...@@ -118,27 +120,27 @@ class _MultiProcessZMQDataFlow(DataFlow): ...@@ -118,27 +120,27 @@ class _MultiProcessZMQDataFlow(DataFlow):
pass pass
class MultiProcessRunner(ProxyDataFlow):
    """
    Running a DataFlow in >=1 processes using Python multiprocessing utilities.
    It will fork the process that calls :meth:`__init__`, collect datapoints from `ds` in each
    process by a Python :class:`multiprocessing.Queue`.

    Note:
        1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
           that the process will be forked ``num_proc`` times.
           There will be ``num_proc`` dataflows running in parallel and **independently**.
           As a result, we have the following guarantee on the dataflow correctness:

           a. When ``num_proc=1``, this dataflow produces the same data as the
              given dataflow in the same order.
           b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
              then this dataflow produces the **same distribution** of data as the given dataflow.
              This implies that there will be duplication, reordering, etc.
              You probably only want to use it for training.

              For example, if your original dataflow contains no randomness and produces the same first datapoint,
              then after parallel prefetching, the datapoint will be produced ``num_proc`` times
              at the beginning.
              Even when your original dataflow is fully shuffled, you still need to be aware of the
              `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
...@@ -146,10 +148,11 @@ class MultiProcessPrefetchData(ProxyDataFlow):
           To utilize parallelism with more strict data integrity, you can use
           the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
        2. This has more serialization overhead than :class:`MultiProcessRunnerZMQ` when data is large.
        3. You can nest like this: ``MultiProcessRunnerZMQ(MultiProcessRunner(df, num_proc=a), num_proc=b)``.
           A total of ``a`` instances of ``df`` worker processes will be created.
        4. Fork happens in `__init__`. `reset_state()` is a no-op.
           DataFlow in the worker processes will be reset at the time of fork.
        5. This DataFlow does support Windows. However, Windows requires more strict picklability on processes,
           which means that some code that's forkable on Linux may not be forkable on Windows. If that happens you'll
           need to re-organize some part of code that's not forkable.
...@@ -157,7 +160,7 @@ class MultiProcessPrefetchData(ProxyDataFlow):
    class _Worker(mp.Process):
        def __init__(self, ds, queue, idx):
            super(MultiProcessRunner._Worker, self).__init__()
            self.ds = ds
            self.queue = queue
            self.idx = idx
...@@ -170,33 +173,43 @@ class MultiProcessPrefetchData(ProxyDataFlow):
            for dp in self.ds:
                self.queue.put(dp)
    def __init__(self, ds, num_prefetch=None, num_proc=None, nr_prefetch=None, nr_proc=None):
        """
        Args:
            ds (DataFlow): input DataFlow.
            num_prefetch (int): size of the queue to hold prefetched datapoints.
            num_proc (int): number of processes to use.
            nr_prefetch, nr_proc: deprecated argument names
        """
        if nr_prefetch is not None:
            log_deprecated("MultiProcessRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
            num_prefetch = nr_prefetch
        if nr_proc is not None:
            log_deprecated("MultiProcessRunner(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
            num_proc = nr_proc
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#the-spawn-and-forkserver-start-methods
        if os.name == 'nt':
            logger.warn("MultiProcessRunner does support Windows. \
However, Windows requires more strict picklability on processes, which may \
lead to failure on some of the code.")
        super(MultiProcessRunner, self).__init__(ds)
        try:
            self._size = len(ds)
        except NotImplementedError:
            self._size = -1
        assert num_proc > 0, num_proc
        assert num_prefetch > 0, num_prefetch
        self.num_proc = num_proc
        self.num_prefetch = num_prefetch

        if num_proc > 1:
            logger.info("[MultiProcessRunner] Will fork a dataflow more than once. "
                        "This assumes the datapoints are i.i.d.")

        self.queue = mp.Queue(self.num_prefetch)
        self.procs = [MultiProcessRunner._Worker(self.ds, self.queue, idx)
                      for idx in range(self.num_proc)]
        ensure_proc_terminate(self.procs)
        start_proc_mask_signal(self.procs)
...@@ -212,31 +225,28 @@ lead to failure on some of the code.")
        pass
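To make the notes above concrete, here is a minimal usage sketch; ``MyDataFlow`` is a hypothetical user-defined DataFlow, while ``BatchData`` is the standard tensorpack wrapper:

````python
from tensorpack.dataflow import BatchData

df = MyDataFlow()          # hypothetical DataFlow producing i.i.d. samples
df = BatchData(df, 64)     # group datapoints into batches of 64
df = MultiProcessRunner(df, num_prefetch=256, num_proc=4)  # fork 4 independent copies
df.reset_state()           # a no-op here: the fork already happened in __init__
for dp in df:
    pass                   # feed each datapoint to the trainer
````

Because the four copies run independently, expect duplicated and reordered datapoints, exactly as note 1 explains.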
class MultiProcessRunnerZMQ(_MultiProcessZMQDataFlow):
    """
    Run a DataFlow in >=1 processes, with ZeroMQ for communication.
    It will fork the calling process of :meth:`reset_state()`,
    and collect datapoints from the given dataflow in each process by ZeroMQ IPC pipe.
    This is typically faster than :class:`MultiProcessRunner`.

    Note:
        1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
           that the process will be forked ``num_proc`` times.
           There will be ``num_proc`` dataflows running in parallel and **independently**.
           As a result, we have the following guarantee on the dataflow correctness:

           a. When ``num_proc=1``, this dataflow produces the same data as the
              given dataflow in the same order.
           b. When ``num_proc>1``, if each sample from the given dataflow is i.i.d.,
              then this dataflow produces the **same distribution** of data as the given dataflow.
              This implies that there will be duplication, reordering, etc.
              You probably only want to use it for training.

              For example, if your original dataflow contains no randomness and produces the same first datapoint,
              then after parallel prefetching, the datapoint will be produced ``num_proc`` times
              at the beginning.
              Even when your original dataflow is fully shuffled, you still need to be aware of the
              `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
...@@ -251,10 +261,10 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
           it's better to fork before creating the session.
        4. (Fork-safety) After the fork has happened, this dataflow becomes not fork-safe.
           i.e., if you fork an already reset instance of this dataflow,
           it won't be usable in the forked process. Therefore, do not nest two `MultiProcessRunnerZMQ`.
        5. (Thread-safety) ZMQ is not thread safe. Therefore, do not call :meth:`get_data` of the same dataflow in
           more than 1 thread.
        6. This dataflow does not support Windows. Use `MultiProcessRunner` which works on Windows.
        7. (For Mac only) A UNIX named pipe will be created in the current directory.
           However, certain non-local filesystems such as NFS/GlusterFS/AFS don't always support pipes.
           You can change the directory by ``export TENSORPACK_PIPEDIR=/other/dir``.
...@@ -269,7 +279,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
    class _Worker(mp.Process):
        def __init__(self, ds, conn_name, hwm, idx):
            super(MultiProcessRunnerZMQ._Worker, self).__init__()
            self.ds = ds
            self.conn_name = conn_name
            self.hwm = hwm
...@@ -293,21 +303,25 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
                socket.close(0)
                context.destroy(0)
    def __init__(self, ds, num_proc=1, hwm=50, nr_proc=None):
        """
        Args:
            ds (DataFlow): input DataFlow.
            num_proc (int): number of processes to use.
            hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
            nr_proc: deprecated
        """
        if nr_proc is not None:
            log_deprecated("MultiProcessRunnerZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
            num_proc = nr_proc
        super(MultiProcessRunnerZMQ, self).__init__()

        self.ds = ds
        self.num_proc = num_proc
        self._hwm = hwm

        if num_proc > 1:
            logger.info("[MultiProcessRunnerZMQ] Will fork a dataflow more than once. "
                        "This assumes the datapoints are i.i.d.")
        try:
            self._size = ds.__len__()
...@@ -321,14 +335,14 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
        return self.ds.__len__()

    def __iter__(self):
        with self._guard, _zmq_catch_error('MultiProcessRunnerZMQ'):
            for k in itertools.count():
                if self._size > 0 and k >= self._size:
                    break
                yield self._recv()

    def reset_state(self):
        super(MultiProcessRunnerZMQ, self).reset_state()
        self._guard = DataFlowReentrantGuard()
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PULL)
...@@ -336,32 +350,31 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
        pipename = _get_pipe_name('dataflow')
        _bind_guard(self.socket, pipename)

        self._procs = [MultiProcessRunnerZMQ._Worker(self.ds, pipename, self._hwm, idx)
                       for idx in range(self.num_proc)]
        self._start_processes()
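A minimal sketch of the ZMQ variant, using the same hypothetical ``MyDataFlow`` as above; note that the fork is deferred to :meth:`reset_state()` here:

````python
df = MultiProcessRunnerZMQ(MyDataFlow(), num_proc=8, hwm=100)
df.reset_state()   # the 8 worker processes are forked here, not in __init__
for dp in df:
    pass
# Per notes 4-5: don't fork this instance again after reset_state(), don't nest
# two MultiProcessRunnerZMQ, and don't iterate the same instance from several threads.
````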
class MultiThreadRunner(DataFlow):
    """
    Create multiple dataflow instances and run them each in one thread.
    Collect outputs from them with a queue.

    Note:
        1. (Data integrity) An iterator cannot run faster automatically -- what's happening is
           that each thread will create a dataflow iterator.
           There will be ``num_thread`` dataflows running in parallel and **independently**.
           As a result, we have the following guarantee on the dataflow correctness:

           a. When ``num_thread=1``, this dataflow produces the same data as the
              given dataflow in the same order.
           b. When ``num_thread>1``, if each sample from the given dataflow is i.i.d.,
              then this dataflow produces the **same distribution** of data as the given dataflow.
              This implies that there will be duplication, reordering, etc.
              You probably only want to use it for training.

              For example, if your original dataflow contains no randomness and produces the same first datapoint,
              then after parallel prefetching, the datapoint will be produced ``num_thread`` times
              at the beginning.
              Even when your original dataflow is fully shuffled, you still need to be aware of the
              `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
...@@ -373,7 +386,7 @@ class MultiThreadPrefetchData(DataFlow):
    class _Worker(StoppableThread):
        def __init__(self, get_df, queue):
            super(MultiThreadRunner._Worker, self).__init__()
            self.df = get_df()
            assert isinstance(self.df, DataFlow), self.df
            self.queue = queue
...@@ -395,21 +408,29 @@ class MultiThreadPrefetchData(DataFlow):
            finally:
                self.stop()
    def __init__(self, get_df, num_prefetch=None, num_thread=None, nr_prefetch=None, nr_thread=None):
        """
        Args:
            get_df ( -> DataFlow): a callable which returns a DataFlow.
                Each thread will call this function to get the DataFlow to use.
                Therefore do not return the same DataFlow for each call.
            num_prefetch (int): size of the queue
            num_thread (int): number of threads
            nr_prefetch, nr_thread: deprecated names
        """
        if nr_prefetch is not None:
            log_deprecated("MultiThreadRunner(nr_prefetch)", "Renamed to 'num_prefetch'", "2020-01-01")
            num_prefetch = nr_prefetch
        if nr_thread is not None:
            log_deprecated("MultiThreadRunner(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
            num_thread = nr_thread
        assert num_thread > 0, num_thread
        assert num_prefetch > 0, num_prefetch
        self.num_thread = num_thread
        self.queue = queue.Queue(maxsize=num_prefetch)
        self.threads = [
            MultiThreadRunner._Worker(get_df, self.queue)
            for _ in range(num_thread)]
    def reset_state(self):
        for th in self.threads:
...@@ -482,13 +503,19 @@ plasma = None
# PlasmaGetData = create_dummy_class('PlasmaGetData', 'pyarrow')  # noqa
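Since every worker thread of :class:`MultiThreadRunner` calls ``get_df()`` to build its own DataFlow, the callable must return a fresh instance on each call. A sketch with the hypothetical ``MyDataFlow``:

````python
# Correct: the factory creates a new DataFlow per worker thread.
df = MultiThreadRunner(lambda: MyDataFlow(), num_prefetch=128, num_thread=4)

# Wrong: all threads would share (and concurrently iterate) one instance.
# shared = MyDataFlow()
# df = MultiThreadRunner(lambda: shared, num_prefetch=128, num_thread=4)
````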
# The old (now deprecated) names, kept as aliases:
PrefetchData = MultiProcessRunner
MultiProcessPrefetchData = MultiProcessRunner
PrefetchDataZMQ = MultiProcessRunnerZMQ
MultiThreadPrefetchData = MultiThreadRunner
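Assuming these aliases stay re-exported by ``tensorpack.dataflow``, existing code that imports the old names keeps working unchanged:

````python
from tensorpack.dataflow import PrefetchDataZMQ, MultiProcessRunnerZMQ

assert PrefetchDataZMQ is MultiProcessRunnerZMQ   # the old name is now a plain alias
````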
if __name__ == '__main__':
    import time
    from .raw import DataFromGenerator
    from .common import FixedSizeData
    x = DataFromGenerator(itertools.count())
    x = FixedSizeData(x, 100)
    x = MultiProcessRunnerZMQ(x, 2)
    x.reset_state()

    for idx, dp in enumerate(x):
        print(dp)
...
...@@ -10,11 +10,12 @@ from six.moves import queue
from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps, loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error

__all__ = ['MultiThreadMapData',
           'MultiProcessMapData', 'MultiProcessMapDataZMQ']
...@@ -115,7 +116,7 @@ class MultiThreadMapData(_ParallelMapData):
        1. You should avoid starting many threads in your main process to reduce GIL contention.

           The threads will only start in the process which calls :meth:`reset_state()`.
           Therefore you can use ``MultiProcessRunnerZMQ(MultiThreadMapData(...), 1)``
           to reduce GIL contention.
    """

    class _Worker(StoppableThread):
...@@ -143,16 +144,21 @@ class MultiThreadMapData(_ParallelMapData):
            finally:
                self.stop()
    def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
        """
        Args:
            ds (DataFlow): the dataflow to map
            num_thread (int): number of threads to use
            map_func (callable): datapoint -> datapoint | None. Return None to
                discard/skip the datapoint.
            buffer_size (int): number of datapoints in the buffer
            strict (bool): use "strict mode", see notes above.
            nr_thread: deprecated name
        """
        if nr_thread is not None:
            log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
            num_thread = nr_thread
        if strict:
            # In strict mode, buffer size cannot be larger than the total number of datapoints
            try:
...@@ -161,10 +167,10 @@ class MultiThreadMapData(_ParallelMapData):
                pass

        super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)

        assert num_thread > 0, num_thread
        self._strict = strict
        self.num_thread = num_thread
        self.map_func = map_func
        self._threads = []
        self._evt = None
...@@ -181,7 +187,7 @@ class MultiThreadMapData(_ParallelMapData):
        self._evt = threading.Event()
        self._threads = [MultiThreadMapData._Worker(
            self._in_queue, self._out_queue, self._evt, self.map_func)
            for _ in range(self.num_thread)]
        for t in self._threads:
            t.start()
...@@ -211,10 +217,6 @@ class MultiThreadMapData(_ParallelMapData):
        #     logger.warn("Cannot join thread {}.".format(p.name))
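Putting the GIL note above into code: a sketch that maps datapoints with many threads, then moves all of them into a single separate process; ``MyDataFlow`` and ``decode_image`` are hypothetical placeholders:

````python
def load_image(dp):
    # dp is a datapoint (a list); returning None drops the datapoint
    img = decode_image(dp[0])          # hypothetical decoding helper
    return None if img is None else [img, dp[1]]

df = MultiThreadMapData(MyDataFlow(), num_thread=16, map_func=load_image)
df = MultiProcessRunnerZMQ(df, num_proc=1)   # threads start in one child process
````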
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
    """
    Same as :class:`MapData`, but start processes to run the mapping function,
...@@ -255,16 +257,20 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
            dp = self.map_func(dp)
            socket.send(dumps(dp), copy=False)
    def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
        """
        Args:
            ds (DataFlow): the dataflow to map
            num_proc (int): number of processes to use
            map_func (callable): datapoint -> datapoint | None. Return None to
                discard/skip the datapoint.
            buffer_size (int): number of datapoints in the buffer
            strict (bool): use "strict mode", see notes above.
            nr_proc: deprecated name
        """
        if nr_proc is not None:
            log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
            num_proc = nr_proc
        if strict:
            # In strict mode, buffer size cannot be larger than the total number of datapoints
            try:
...@@ -274,8 +280,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
        _ParallelMapData.__init__(self, ds, buffer_size, strict)
        _MultiProcessZMQDataFlow.__init__(self)
        assert num_proc > 0, num_proc
        self.num_proc = num_proc
        self.map_func = map_func
        self._strict = strict
        self._procs = []
...@@ -291,11 +297,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
        pipename = _get_pipe_name('dataflow-map')
        _bind_guard(self.socket, pipename)

        self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
        worker_hwm = int(self._buffer_size * 2 // self.num_proc)
        self._procs = [MultiProcessMapDataZMQ._Worker(
            self._proc_ids[k], self.map_func, pipename, worker_hwm)
            for k in range(self.num_proc)]

        self._start_processes()
        self._fill_buffer()  # pre-fill the buffer
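For completeness, a sketch of the process-based mapper under the same hypothetical names; per the strict-mode note, ``strict=True`` asks for a one-to-one mapping between input and output datapoints, at the cost of the buffer-size restriction mentioned above:

````python
df = MultiProcessMapDataZMQ(MyDataFlow(), num_proc=8, map_func=load_image,
                            buffer_size=200, strict=True)
df.reset_state()   # forks the 8 mapper processes
for dp in df:
    pass
````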
...