Commit dd2d9ffa authored by Yuxin Wu's avatar Yuxin Wu

Rename PrefetchData -> MultiProcessRunner

parent 0cecfbb6
...@@ -372,7 +372,6 @@ _DEPRECATED_NAMES = set([ ...@@ -372,7 +372,6 @@ _DEPRECATED_NAMES = set([
# deprecated stuff: # deprecated stuff:
'QueueInputTrainer', 'QueueInputTrainer',
'dump_dataflow_to_process_queue', 'dump_dataflow_to_process_queue',
'PrefetchOnGPUs',
'DistributedTrainerReplicated', 'DistributedTrainerReplicated',
'DistributedTrainerParameterServer', 'DistributedTrainerParameterServer',
'InputDesc', 'InputDesc',
...@@ -382,11 +381,14 @@ _DEPRECATED_NAMES = set([ ...@@ -382,11 +381,14 @@ _DEPRECATED_NAMES = set([
'DumpTensor', 'DumpTensor',
'DumpParamAsImage', 'DumpParamAsImage',
'get_nr_gpu', 'get_nr_gpu',
'start_test', # TestDataSpeed
'ThreadedMapData',
'TrainingMonitor', 'TrainingMonitor',
'PeakMemoryTracker', 'PeakMemoryTracker',
'PrefetchData',
'MultiProcessPrefetchData',
'PrefetchDataZMQ',
'MultiThreadPrefetchData',
# deprecated or renamed symbolic code # deprecated or renamed symbolic code
'Deconv2D', 'psnr', 'Deconv2D', 'psnr',
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
### What is DataFlow ### What is DataFlow
DataFlow is a library to build Python iterators for efficient data loading. DataFlow is a pure-Python library to create iterators for efficient data loading.
**Definition**: A DataFlow is a idiomatic Python container object that has a `__iter__()` generator method, **Definition**: A DataFlow is a idiomatic Python iterator object that has a `__iter__()` method
which yields `datapoints` and optionally a `__len__()` method returning the size of the flow. which yields `datapoints`, and optionally a `__len__()` method returning the size of the DataFlow.
A datapoint is a **list** of Python objects which are called the `components` of a datapoint. A datapoint is a **list or dict** of Python objects, each of which are called the `components` of a datapoint.
**Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method **Example**: to train on MNIST dataset, you may need a DataFlow with a `__iter__()` method
that yields datapoints (lists) of two components: that yields datapoints (lists) of two components:
...@@ -21,12 +21,10 @@ You can simply use DataFlow as a data processing pipeline and plug it into any o ...@@ -21,12 +21,10 @@ You can simply use DataFlow as a data processing pipeline and plug it into any o
### Composition of DataFlow ### Composition of DataFlow
One good thing about having a standard interface is to be able to provide
the greatest code reusability.
There are a lot of existing DataFlow utilities in tensorpack, which you can use to compose There are a lot of existing DataFlow utilities in tensorpack, which you can use to compose
DataFlow with complex data pipeline. A common pipeline usually one DataFlow with complex data pipeline. A common pipeline usually
would __read from disk (or other sources), apply transformations, group into batches, would __read from disk (or other sources), apply transformations (possibly in parallel), group into batches,
prefetch data__, etc. A simple example is as the following: prefetch data__, etc, and all __run in parallel__. A simple example is as the following:
````python ````python
# a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources: # a DataFlow you implement to produce [tensor1, tensor2, ..] lists from whatever sources:
...@@ -36,17 +34,17 @@ df = AugmentImageComponent(df, [imgaug.Resize((225, 225))]) ...@@ -36,17 +34,17 @@ df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
# group data into batches of size 128 # group data into batches of size 128
df = BatchData(df, 128) df = BatchData(df, 128)
# start 3 processes to run the dataflow in parallel # start 3 processes to run the dataflow in parallel
df = PrefetchDataZMQ(df, 3) df = MultiProcessRunnerZMQ(df, 3)
```` ````
You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py) You can find more complicated DataFlow in the [ImageNet training script](../examples/ImageNetModels/imagenet_utils.py)
with all the data preprocessing. with all the data preprocessing.
### Work with Your Data ### Work with Your Data
Unless you are working with standard data types (image folders, LMDB, etc), We do not make any assumptions about your data format.
you would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your data format. You would usually want to write the source DataFlow (`MyDataFlow` in the above example) for your own data format.
See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow. See [another tutorial](extend/dataflow.html) for simple instructions on writing a DataFlow.
Once you have the source reader, all the [existing Once you have the source reader, all the [built-in
DataFlows](../modules/dataflow.html) are ready for you to build up the rest of the data pipeline. DataFlows](../modules/dataflow.html) are ready for you to assemble the rest of the data pipeline.
### Why DataFlow ### Why DataFlow
...@@ -62,16 +60,16 @@ Nevertheless, tensorpack supports data loading with native TF operators / TF dat ...@@ -62,16 +60,16 @@ Nevertheless, tensorpack supports data loading with native TF operators / TF dat
### Use DataFlow in Your Own Code ### Use DataFlow in Your Own Code
Normally, tensorpack `InputSource` interface runs the DataFlow during training. When training with tensorpack, typically it is the `InputSource` interface that runs the DataFlow.
However, DataFlow can also be used without other tensorpack components. However, DataFlow can be used without other tensorpack components.
If you need to run the DataFlow by yourself, call `reset_state()` first to initialize it, To run a DataFlow by yourself, call `reset_state()` first to initialize it,
and then use the generator however you like: and then use the generator however you like:
```python ```python
df = SomeDataFlow() df = SomeDataFlow()
df.reset_state() df.reset_state()
for dp in df: for dp in df:
# dp is now a list. do whatever # dp is now a list. do whatever
``` ```
Read the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlw) Read the [API documentation](../../modules/dataflow.html#tensorpack.dataflow.DataFlw)
......
...@@ -16,7 +16,7 @@ then apply complicated preprocessing to it. ...@@ -16,7 +16,7 @@ then apply complicated preprocessing to it.
We aim to reach a speed of, roughly **1k~3k images per second**, to keep GPUs busy. We aim to reach a speed of, roughly **1k~3k images per second**, to keep GPUs busy.
Some things to know before reading: Some things to know before reading:
1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess prefetch should usually work well enough. 1. For smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some multiprocess runner should usually work well enough.
Therefore you don't have to understand this tutorial in depth unless you really find your data being the bottleneck. Therefore you don't have to understand this tutorial in depth unless you really find your data being the bottleneck.
This tutorial could be a bit complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-scale dataset. This tutorial could be a bit complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-scale dataset.
2. Having a fast Python generator **alone** may or may not improve your overall training speed. 2. Having a fast Python generator **alone** may or may not improve your overall training speed.
...@@ -64,7 +64,7 @@ On a good filesystem you probably can already observe good speed here (e.g. 5 it ...@@ -64,7 +64,7 @@ On a good filesystem you probably can already observe good speed here (e.g. 5 it
because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True). because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
Image decoding in `cv2.imread` could also be a bottleneck at this early stage. Image decoding in `cv2.imread` could also be a bottleneck at this early stage.
### Parallel Prefetch ### Parallel Runner
We will now add the cheapest pre-processing now to get an ndarray in the end instead of a list We will now add the cheapest pre-processing now to get an ndarray in the end instead of a list
(because training will need ndarray eventually): (because training will need ndarray eventually):
...@@ -84,15 +84,15 @@ Now it's time to add threads or processes: ...@@ -84,15 +84,15 @@ Now it's time to add threads or processes:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True) ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
ds1 = AugmentImageComponent(ds0, lots_of_augmentors) ds1 = AugmentImageComponent(ds0, lots_of_augmentors)
ds = PrefetchDataZMQ(ds1, nr_proc=25) ds = MultiProcessRunnerZMQ(ds1, num_proc=25)
ds = BatchData(ds, 256) ds = BatchData(ds, 256)
``` ```
Here we fork 25 processes to run `ds1`, and collect their output through ZMQ IPC protocol, Here we fork 25 processes to run `ds1`, and collect their output through ZMQ IPC protocol,
which is faster than `multiprocessing.Queue`. You can also apply prefetch after batch, of course. which is faster than `multiprocessing.Queue`. You can also apply parallel runner after batching, of course.
### Parallel Map ### Parallel Map
The above DataFlow might be fast, but since it forks the ImageNet reader (`ds0`), The above DataFlow might be fast, but since it forks the ImageNet reader (`ds0`),
it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ)). it's **not a good idea to use it for validation** (for reasons mentioned at top. More details at the [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)).
Alternatively, you can use multi-threaded preprocessing like this: Alternatively, you can use multi-threaded preprocessing like this:
```eval_rst ```eval_rst
...@@ -102,9 +102,9 @@ Alternatively, you can use multi-threaded preprocessing like this: ...@@ -102,9 +102,9 @@ Alternatively, you can use multi-threaded preprocessing like this:
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True) ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors) augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData( ds1 = MultiThreadMapData(
ds0, nr_thread=25, ds0, num_thread=25,
map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000) map_func=lambda dp: [augmentor.augment(dp[0]), dp[1]], buffer_size=1000)
# ds1 = PrefetchDataZMQ(ds1, nr_proc=1) # ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256) ds = BatchData(ds1, 256)
``` ```
`MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single `MultiThreadMapData` launches a thread pool to fetch data and apply the mapping function on **a single
...@@ -127,11 +127,11 @@ If you identify this as a bottleneck, you can also use: ...@@ -127,11 +127,11 @@ If you identify this as a bottleneck, you can also use:
ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True) ds0 = dataset.ILSVRC12Files('/path/to/ILSVRC12', 'train', shuffle=True)
augmentor = AugmentorList(lots_of_augmentors) augmentor = AugmentorList(lots_of_augmentors)
ds1 = MultiThreadMapData( ds1 = MultiThreadMapData(
ds0, nr_thread=25, ds0, num_thread=25,
map_func=lambda dp: map_func=lambda dp:
[augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]], [augmentor.augment(cv2.imread(dp[0], cv2.IMREAD_COLOR)), dp[1]],
buffer_size=1000) buffer_size=1000)
ds1 = PrefetchDataZMQ(ds1, nr_proc=1) ds1 = MultiProcessRunnerZMQ(ds1, num_proc=1)
ds = BatchData(ds1, 256) ds = BatchData(ds1, 256)
``` ```
...@@ -159,15 +159,15 @@ class BinaryILSVRC12(dataset.ILSVRC12Files): ...@@ -159,15 +159,15 @@ class BinaryILSVRC12(dataset.ILSVRC12Files):
jpeg = np.asarray(bytearray(jpeg), dtype='uint8') jpeg = np.asarray(bytearray(jpeg), dtype='uint8')
yield [jpeg, label] yield [jpeg, label]
ds0 = BinaryILSVRC12('/path/to/ILSVRC/', 'train') ds0 = BinaryILSVRC12('/path/to/ILSVRC/', 'train')
ds1 = PrefetchDataZMQ(ds0, nr_proc=1) ds1 = MultiProcessRunnerZMQ(ds0, num_proc=1)
LMDBSerializer.save(ds1, '/path/to/ILSVRC-train.lmdb') LMDBSerializer.save(ds1, '/path/to/ILSVRC-train.lmdb')
``` ```
The above script builds a DataFlow which produces jpeg-encoded ImageNet data. The above script builds a DataFlow which produces jpeg-encoded ImageNet data.
We store the jpeg string as a numpy array because the function `cv2.imdecode` later expect this format. We store the jpeg string as a numpy array because the function `cv2.imdecode` later expect this format.
Please note we can only use 1 prefetch process to speed up. If `nr_proc>1`, `ds1` will take data Please note we can only use 1 runner process to speed up. If `num_proc>1`, `ds1` will take data
from several forks of `ds0`, then neither the content nor the order of `ds1` will be the same as `ds0`. from several forks of `ds0`, then neither the content nor the order of `ds1` will be the same as `ds0`.
See [documentation](../modules/dataflow.html#tensorpack.dataflow.PrefetchDataZMQ) See [documentation](../modules/dataflow.html#tensorpack.dataflow.MultiProcessRunnerZMQ)
about caveats of `PrefetchDataZMQ`. about caveats of `MultiProcessRunnerZMQ`.
It will generate a database file of 140G. We load the DataFlow back by reading this LMDB file sequentially: It will generate a database file of 140G. We load the DataFlow back by reading this LMDB file sequentially:
``` ```
...@@ -193,7 +193,7 @@ the added line above maintains a buffer of datapoints and shuffle them once a wh ...@@ -193,7 +193,7 @@ the added line above maintains a buffer of datapoints and shuffle them once a wh
It will not affect the model as long as the buffer is large enough, It will not affect the model as long as the buffer is large enough,
but it can also consume much memory if too large. but it can also consume much memory if too large.
### Augmentations & Parallel Prefetch ### Augmentations & Parallel Runner
Then we add necessary transformations: Then we add necessary transformations:
```eval_rst ```eval_rst
...@@ -218,24 +218,24 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like ...@@ -218,24 +218,24 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
ds = LMDBSerializer.load(db, shuffle=False) ds = LMDBSerializer.load(db, shuffle=False)
ds = LocallyShuffleData(ds, 50000) ds = LocallyShuffleData(ds, 50000)
ds = PrefetchData(ds, 5000, 1) ds = MultiProcessRunner(ds, 5000, 1)
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0) ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors) ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = PrefetchDataZMQ(ds, 25) ds = MultiProcessRunnerZMQ(ds, 25)
ds = BatchData(ds, 256) ds = BatchData(ds, 256)
``` ```
Since we are reading the database sequentially, having multiple forked instances of the Since we are reading the database sequentially, having multiple forked instances of the
base LMDB reader will result in biased data distribution. Therefore we use `PrefetchData` to base LMDB reader will result in biased data distribution. Therefore we use `MultiProcessRunner` to
launch the base DataFlow in only **one process**, and only parallelize the transformations launch the base DataFlow in only **one process**, and only parallelize the transformations
with another `PrefetchDataZMQ` with another `MultiProcessRunnerZMQ`
(Nesting two `PrefetchDataZMQ`, however, will result in a different behavior. (Nesting two `MultiProcessRunnerZMQ`, however, will result in a different behavior.
These differences are explained in the API documentation in more details.). These differences are explained in the API documentation in more details.).
Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well. Similar to what we did earlier, you can use `MultiThreadMapData` to parallelize as well.
Let me summarize what this DataFlow does: Let me summarize what this DataFlow does:
1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `PrefetchData`). 1. One process reads LMDB file, shuffle them in a buffer and put them into a `multiprocessing.Queue` (used by `MultiProcessRunner`).
2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and 2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and
send them through ZMQ IPC pipe. send them through ZMQ IPC pipe.
3. The main process takes data from the pipe, makes batches. 3. The main process takes data from the pipe, makes batches.
......
...@@ -82,7 +82,7 @@ def get_data(path, isTrain, stat_file): ...@@ -82,7 +82,7 @@ def get_data(path, isTrain, stat_file):
ds = MapDataComponent(ds, lambda x: (x - mean) / std) ds = MapDataComponent(ds, lambda x: (x - mean) / std)
ds = TIMITBatch(ds, BATCH) ds = TIMITBatch(ds, BATCH)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, 1) ds = MultiProcessRunnerZMQ(ds, 1)
return ds return ds
......
...@@ -32,7 +32,7 @@ def get_data(): ...@@ -32,7 +32,7 @@ def get_data():
] ]
data_train = AugmentImageComponent(data_train, augmentors) data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128) data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5) data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))] augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors) data_test = AugmentImageComponent(data_test, augmentors)
......
...@@ -148,7 +148,7 @@ def get_config(): ...@@ -148,7 +148,7 @@ def get_config():
] ]
data_train = AugmentImageComponent(data_train, augmentors) data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128) data_train = BatchData(data_train, 128)
data_train = PrefetchDataZMQ(data_train, 5) data_train = MultiProcessRunnerZMQ(data_train, 5)
augmentors = [imgaug.Resize((40, 40))] augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors) data_test = AugmentImageComponent(data_test, augmentors)
......
...@@ -225,7 +225,7 @@ def get_data(): ...@@ -225,7 +225,7 @@ def get_data():
ds = ThetaImages(ds) ds = ThetaImages(ds)
ds = RepeatedData(ds, 50) # just pretend this dataset is bigger ds = RepeatedData(ds, 50) # just pretend this dataset is bigger
# this pre-computation is pretty heavy # this pre-computation is pretty heavy
ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count())) ds = MultiProcessRunnerZMQ(ds, min(20, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
return ds return ds
......
...@@ -9,7 +9,7 @@ from tabulate import tabulate ...@@ -9,7 +9,7 @@ from tabulate import tabulate
from termcolor import colored from termcolor import colored
from tensorpack.dataflow import ( from tensorpack.dataflow import (
DataFromList, MapData, MapDataComponent, MultiProcessMapDataZMQ, MultiThreadMapData, DataFromList, MapData, MapDataComponent, MultiProcessMapData, MultiThreadMapData,
TestDataSpeed, imgaug) TestDataSpeed, imgaug)
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.argtools import log_once, memoized from tensorpack.utils.argtools import log_once, memoized
...@@ -392,7 +392,7 @@ def get_train_dataflow(): ...@@ -392,7 +392,7 @@ def get_train_dataflow():
# MPI does not like fork() # MPI does not like fork()
else: else:
buffer_size = cfg.DATA.NUM_WORKERS * 20 buffer_size = cfg.DATA.NUM_WORKERS * 20
ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size) ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
else: else:
ds = MapData(ds, preprocess) ds = MapData(ds, preprocess)
return ds return ds
......
...@@ -177,7 +177,7 @@ def get_data(datadir, isTrain=True): ...@@ -177,7 +177,7 @@ def get_data(datadir, isTrain=True):
names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB'] names = ['trainA', 'trainB'] if isTrain else ['testA', 'testB']
df = get_image_pairs(*[os.path.join(datadir, n) for n in names]) df = get_image_pairs(*[os.path.join(datadir, n) for n in names])
df = BatchData(df, BATCH if isTrain else TEST_BATCH) df = BatchData(df, BATCH if isTrain else TEST_BATCH)
df = PrefetchDataZMQ(df, 2 if isTrain else 1) df = MultiProcessRunnerZMQ(df, 2 if isTrain else 1)
return df return df
......
...@@ -115,7 +115,7 @@ def get_data(): ...@@ -115,7 +115,7 @@ def get_data():
ds = ImageFromFile(imgs, channel=3, shuffle=True) ds = ImageFromFile(imgs, channel=3, shuffle=True)
ds = AugmentImageComponent(ds, get_augmentors()) ds = AugmentImageComponent(ds, get_augmentors())
ds = BatchData(ds, args.batch) ds = BatchData(ds, args.batch)
ds = PrefetchDataZMQ(ds, 5) ds = MultiProcessRunnerZMQ(ds, 5)
return ds return ds
......
...@@ -186,7 +186,7 @@ def get_celebA_data(datadir, styleA, styleB=None): ...@@ -186,7 +186,7 @@ def get_celebA_data(datadir, styleA, styleB=None):
imgaug.Resize(64)] imgaug.Resize(64)]
df = AugmentImageComponents(df, augs, (0, 1)) df = AugmentImageComponents(df, augs, (0, 1))
df = BatchData(df, BATCH) df = BatchData(df, BATCH)
df = PrefetchDataZMQ(df, 3) df = MultiProcessRunnerZMQ(df, 3)
return df return df
......
...@@ -173,7 +173,7 @@ def get_data(): ...@@ -173,7 +173,7 @@ def get_data():
augs = [imgaug.Resize(286), imgaug.RandomCrop(256)] augs = [imgaug.Resize(286), imgaug.RandomCrop(256)]
ds = AugmentImageComponents(ds, augs, (0, 1)) ds = AugmentImageComponents(ds, augs, (0, 1))
ds = BatchData(ds, BATCH) ds = BatchData(ds, BATCH)
ds = PrefetchData(ds, 100, 1) ds = MultiProcessRunner(ds, 100, 1)
return ds return ds
......
...@@ -233,7 +233,7 @@ def get_data(name): ...@@ -233,7 +233,7 @@ def get_data(name):
] ]
ds = AugmentImageComponent(ds, augmentors, copy=False) ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = BatchDataByShape(ds, 8, idx=0) ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1) ds = MultiProcessRunnerZMQ(ds, 1)
else: else:
ds = BatchData(ds, 1) ds = BatchData(ds, 1)
return ds return ds
......
...@@ -11,7 +11,9 @@ import tensorflow as tf ...@@ -11,7 +11,9 @@ import tensorflow as tf
import tqdm import tqdm
from tensorpack import ModelDesc from tensorpack import ModelDesc
from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug from tensorpack.dataflow import (
AugmentImageComponent, BatchData, MultiThreadMapData,
MultiProcessRunnerZMQ, dataset, imgaug)
from tensorpack.input_source import QueueInput, StagingInput from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.models import regularize_cost, l2_regularizer from tensorpack.models import regularize_cost, l2_regularizer
from tensorpack.predict import FeedfreePredictor, PredictConfig from tensorpack.predict import FeedfreePredictor, PredictConfig
...@@ -88,7 +90,7 @@ def get_imagenet_dataflow( ...@@ -88,7 +90,7 @@ def get_imagenet_dataflow(
ds = AugmentImageComponent(ds, augmentors, copy=False) ds = AugmentImageComponent(ds, augmentors, copy=False)
if parallel < 16: if parallel < 16:
logger.warn("DataFlow may become the bottleneck when too few processes are used.") logger.warn("DataFlow may become the bottleneck when too few processes are used.")
ds = PrefetchDataZMQ(ds, parallel) ds = MultiProcessRunnerZMQ(ds, parallel)
ds = BatchData(ds, batch_size, remainder=False) ds = BatchData(ds, batch_size, remainder=False)
else: else:
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False) ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
...@@ -101,7 +103,7 @@ def get_imagenet_dataflow( ...@@ -101,7 +103,7 @@ def get_imagenet_dataflow(
return im, cls return im, cls
ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True) ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True) ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1) ds = MultiProcessRunnerZMQ(ds, 1)
return ds return ds
......
...@@ -133,7 +133,7 @@ def get_data(train_or_test): ...@@ -133,7 +133,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchData(ds, 3, 2) ds = MultiProcessRunner(ds, 3, 2)
return ds return ds
......
...@@ -68,7 +68,7 @@ def get_data(train_or_test): ...@@ -68,7 +68,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors, copy=False) ds = AugmentImageComponent(ds, augmentors, copy=False)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, min(25, multiprocessing.cpu_count())) ds = MultiProcessRunnerZMQ(ds, min(25, multiprocessing.cpu_count()))
ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain) ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
return ds return ds
......
...@@ -254,7 +254,7 @@ def get_data(file_name): ...@@ -254,7 +254,7 @@ def get_data(file_name):
imgaug.Flip(horiz=True)] imgaug.Flip(horiz=True)]
ds = AugmentImageComponent(ds, augmentors, index=0, copy=True) ds = AugmentImageComponent(ds, augmentors, index=0, copy=True)
ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]]) ds = MapData(ds, lambda x: [cv2.resize(x[0], (32, 32), interpolation=cv2.INTER_CUBIC), x[0]])
ds = PrefetchDataZMQ(ds, 3) ds = MultiProcessRunnerZMQ(ds, 3)
ds = BatchData(ds, BATCH_SIZE) ds = BatchData(ds, BATCH_SIZE)
return ds return ds
......
...@@ -103,7 +103,7 @@ def get_data(train_or_test, cifar_classnum): ...@@ -103,7 +103,7 @@ def get_data(train_or_test, cifar_classnum):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchDataZMQ(ds, 5) ds = MultiProcessRunnerZMQ(ds, 5)
return ds return ds
......
...@@ -78,7 +78,7 @@ def get_data(): ...@@ -78,7 +78,7 @@ def get_data():
] ]
data_train = AugmentImageComponent(data_train, augmentors) data_train = AugmentImageComponent(data_train, augmentors)
data_train = BatchData(data_train, 128) data_train = BatchData(data_train, 128)
data_train = PrefetchData(data_train, 5, 5) data_train = MultiProcessRunner(data_train, 5, 5)
augmentors = [imgaug.Resize((40, 40))] augmentors = [imgaug.Resize((40, 40))]
data_test = AugmentImageComponent(data_test, augmentors) data_test = AugmentImageComponent(data_test, augmentors)
......
...@@ -37,7 +37,7 @@ def get_data(subset): ...@@ -37,7 +37,7 @@ def get_data(subset):
# something that yields [[SHAPE, SHAPE, CHANNELS], [1]] # something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False, ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)]) dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
ds = PrefetchDataZMQ(ds, 2) ds = MultiProcessRunnerZMQ(ds, 2)
ds = BatchData(ds, BATCH_SIZE) ds = BatchData(ds, BATCH_SIZE)
return ds return ds
......
...@@ -64,29 +64,25 @@ class DataFlow(object): ...@@ -64,29 +64,25 @@ class DataFlow(object):
@abstractmethod @abstractmethod
def __iter__(self): def __iter__(self):
""" """
* A dataflow is an iterable. The :meth:`__iter__` method should yield a list each time. * A dataflow is an iterable. The :meth:`__iter__` method should yield a list or dict each time.
Each element in the list should be either a number or a numpy array. Note that dict is **partially** supported at the moment: certain dataflow does not support dict.
For now, tensorpack also **partially** supports dict instead of list.
* The :meth:`__iter__` method can be either finite (will stop iteration) or infinite * The :meth:`__iter__` method can be either finite (will stop iteration) or infinite
(will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called (will not stop iteration). For a finite dataflow, :meth:`__iter__` can be called
again after the previous call returned. again immediately after the previous call returned.
* For many dataflow, the :meth:`__iter__` method is non-reentrant, which means for an dataflow * For many dataflow, the :meth:`__iter__` method is non-reentrant, which means for an dataflow
instance ``df``, :meth:`df.__iter__` cannot be called before the previous instance ``df``, :meth:`df.__iter__` cannot be called before the previous
:meth:`df.__iter__` call has finished (iteration has stopped). :meth:`df.__iter__` call has finished (iteration has stopped).
When it is non-reentrant, :meth:`df.__iter__` should throw an exception if When a dataflow is non-reentrant, :meth:`df.__iter__` should throw an exception if
called before the previous call has finished. called before the previous call has finished.
For such non-reentrant dataflows, if you need to use the same dataflow in two places, For such non-reentrant dataflows, if you need to use the same dataflow in two places,
you need to create two dataflow instances. you need to create two dataflow instances.
Yields: Yields:
list: The datapoint, i.e. list of components. list/dict: The datapoint, i.e. list/dict of components.
""" """
def get_data(self):
return self.__iter__()
def __len__(self): def __len__(self):
""" """
* A dataflow can optionally implement :meth:`__len__`. If not implemented, it will * A dataflow can optionally implement :meth:`__len__`. If not implemented, it will
...@@ -95,7 +91,7 @@ class DataFlow(object): ...@@ -95,7 +91,7 @@ class DataFlow(object):
* It returns an integer representing the size of the dataflow. * It returns an integer representing the size of the dataflow.
The return value **may not be accurate or meaningful** at all. The return value **may not be accurate or meaningful** at all.
When saying the length is "accurate", it means that When saying the length is "accurate", it means that
:meth:`__iter__` will always yield this many of datapoints. :meth:`__iter__` will always yield this many of datapoints before it stops iteration.
* There could be many reasons why :meth:`__len__` is inaccurate. * There could be many reasons why :meth:`__len__` is inaccurate.
For example, some dataflow has dynamic size, if it throws away datapoints on the fly. For example, some dataflow has dynamic size, if it throws away datapoints on the fly.
...@@ -103,8 +99,9 @@ class DataFlow(object): ...@@ -103,8 +99,9 @@ class DataFlow(object):
the dataset, due to parallelism and buffering. the dataset, due to parallelism and buffering.
In this case it does not make sense to stop the iteration anywhere. In this case it does not make sense to stop the iteration anywhere.
* Due to the above reasons, the length is only a rough guidance. Inside * Due to the above reasons, the length is only a rough guidance.
tensorpack it's only used in these places: And it's up to the user how to interpret it.
Inside tensorpack it's only used in these places:
+ A default ``steps_per_epoch`` in training, but you probably want to customize + A default ``steps_per_epoch`` in training, but you probably want to customize
it yourself, especially when using data-parallel trainer. it yourself, especially when using data-parallel trainer.
...@@ -121,9 +118,6 @@ class DataFlow(object): ...@@ -121,9 +118,6 @@ class DataFlow(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def size(self):
return self.__len__()
def reset_state(self): def reset_state(self):
""" """
* The caller must guarantee that :meth:`reset_state` should be called **once and only once** * The caller must guarantee that :meth:`reset_state` should be called **once and only once**
...@@ -134,21 +128,28 @@ class DataFlow(object): ...@@ -134,21 +128,28 @@ class DataFlow(object):
e.g., initialize random number generators (RNG), create worker processes. e.g., initialize random number generators (RNG), create worker processes.
Because it's very common to use RNG in data processing, Because it's very common to use RNG in data processing,
developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to an RNG. developers of dataflow can also subclass :class:`RNGDataFlow` to have easier access to
a properly-initialized RNG.
* A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee). * A dataflow is not fork-safe after :meth:`reset_state` is called (because this will violate the guarantee).
A few number of dataflow is not fork-safe anytime, which will be mentioned in the docs. There are a few other dataflows that are not fork-safe anytime, which will be mentioned in the docs.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either when you're using dataflow outside of tensorpack, or if you're writing a wrapper dataflow).
* Tensorpack's built-in forking dataflows (:class:`MultiProcessPrefetchData`, :class:`MultiProcessMapData`, etc) * Tensorpack's built-in forking dataflows (:class:`MultiProcessRunner`, :class:`MultiProcessMapData`, etc)
and other component that uses dataflows (:class:`InputSource`) and other component that uses dataflows (:class:`InputSource`)
already take care of the responsibility of calling this method. already take care of the responsibility of calling this method.
* You should take the responsibility and follow the above guarantee if you're the caller of a dataflow yourself
(either if you're using dtaflow outside of tensorpack,
or if you're writing a wrapper dataflow).
""" """
pass pass
# These are the old (overly verbose) names for the methods:
def get_data(self):
return self.__iter__()
def size(self):
return self.__len__()
class RNGDataFlow(DataFlow): class RNGDataFlow(DataFlow):
""" A DataFlow with RNG""" """ A DataFlow with RNG"""
...@@ -156,7 +157,7 @@ class RNGDataFlow(DataFlow): ...@@ -156,7 +157,7 @@ class RNGDataFlow(DataFlow):
rng = None rng = None
""" """
``self.rng`` is a ``np.random.RandomState`` instance that is initialized ``self.rng`` is a ``np.random.RandomState`` instance that is initialized
correctly in ``RNGDataFlow.reset_state()``. correctly (with different seeds in each process) in ``RNGDataFlow.reset_state()``.
""" """
def reset_state(self): def reset_state(self):
......
This diff is collapsed.
...@@ -166,7 +166,7 @@ class LMDBDataDecoder(MapData): ...@@ -166,7 +166,7 @@ class LMDBDataDecoder(MapData):
def CaffeLMDB(lmdb_path, shuffle=True, keys=None): def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
""" """
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf. Read a Caffe-format LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label]. Produces datapoints of the format: [HWC image, label].
Note that Caffe LMDB format is not efficient: it stores serialized raw Note that Caffe LMDB format is not efficient: it stores serialized raw
...@@ -175,9 +175,6 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None): ...@@ -175,9 +175,6 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
Args: Args:
lmdb_path, shuffle, keys: same as :class:`LMDBData`. lmdb_path, shuffle, keys: same as :class:`LMDBData`.
Returns:
a :class:`LMDBDataDecoder` instance.
Example: Example:
.. code-block:: python .. code-block:: python
......
...@@ -92,7 +92,7 @@ class AugmentImageComponent(MapDataComponent): ...@@ -92,7 +92,7 @@ class AugmentImageComponent(MapDataComponent):
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented in the datapoint. index (int or str): the index or key of the image component to be augmented in the datapoint.
copy (bool): Some augmentors modify the input images. When copy is copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied, True, a copy will be made before any augmentors are applied,
to keep the original images not modified. to keep the original images not modified.
...@@ -134,8 +134,8 @@ class AugmentImageCoordinates(MapData): ...@@ -134,8 +134,8 @@ class AugmentImageCoordinates(MapData):
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order. augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
img_index (int): the index of the image component to be augmented. img_index (int or str): the index/key of the image component to be augmented.
coords_index (int): the index of the coordinate component to be augmented. coords_index (int or str): the index/key of the coordinate component to be augmented.
copy, catch_exceptions: same as in :class:`AugmentImageComponent` copy, catch_exceptions: same as in :class:`AugmentImageComponent`
""" """
if isinstance(augmentors, AugmentorList): if isinstance(augmentors, AugmentorList):
......
This diff is collapsed.
...@@ -10,11 +10,12 @@ from six.moves import queue ...@@ -10,11 +10,12 @@ from six.moves import queue
from ..utils.concurrency import StoppableThread, enable_death_signal from ..utils.concurrency import StoppableThread, enable_death_signal
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
from ..utils.develop import log_deprecated
from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow from .base import DataFlow, DataFlowReentrantGuard, ProxyDataFlow
from .common import RepeatedData from .common import RepeatedData
from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error from .parallel import _bind_guard, _get_pipe_name, _MultiProcessZMQDataFlow, _repeat_iter, _zmq_catch_error
__all__ = ['ThreadedMapData', 'MultiThreadMapData', __all__ = ['MultiThreadMapData',
'MultiProcessMapData', 'MultiProcessMapDataZMQ'] 'MultiProcessMapData', 'MultiProcessMapDataZMQ']
...@@ -115,7 +116,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -115,7 +116,7 @@ class MultiThreadMapData(_ParallelMapData):
1. You should avoid starting many threads in your main process to reduce GIL contention. 1. You should avoid starting many threads in your main process to reduce GIL contention.
The threads will only start in the process which calls :meth:`reset_state()`. The threads will only start in the process which calls :meth:`reset_state()`.
Therefore you can use ``PrefetchDataZMQ(MultiThreadMapData(...), 1)`` Therefore you can use ``MultiProcessRunnerZMQ(MultiThreadMapData(...), 1)``
to reduce GIL contention. to reduce GIL contention.
""" """
class _Worker(StoppableThread): class _Worker(StoppableThread):
...@@ -143,16 +144,21 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -143,16 +144,21 @@ class MultiThreadMapData(_ParallelMapData):
finally: finally:
self.stop() self.stop()
def __init__(self, ds, nr_thread, map_func, buffer_size=200, strict=False): def __init__(self, ds, num_thread=None, map_func=None, buffer_size=200, strict=False, nr_thread=None):
""" """
Args: Args:
ds (DataFlow): the dataflow to map ds (DataFlow): the dataflow to map
nr_thread (int): number of threads to use num_thread (int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint. discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
nr_thread: deprecated name
""" """
if nr_thread is not None:
log_deprecated("MultiThreadMapData(nr_thread)", "Renamed to 'num_thread'", "2020-01-01")
num_thread = nr_thread
if strict: if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints # In strict mode, buffer size cannot be larger than the total number of datapoints
try: try:
...@@ -161,10 +167,10 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -161,10 +167,10 @@ class MultiThreadMapData(_ParallelMapData):
pass pass
super(MultiThreadMapData, self).__init__(ds, buffer_size, strict) super(MultiThreadMapData, self).__init__(ds, buffer_size, strict)
assert nr_thread > 0, nr_thread assert num_thread > 0, num_thread
self._strict = strict self._strict = strict
self.nr_thread = nr_thread self.num_thread = num_thread
self.map_func = map_func self.map_func = map_func
self._threads = [] self._threads = []
self._evt = None self._evt = None
...@@ -181,7 +187,7 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -181,7 +187,7 @@ class MultiThreadMapData(_ParallelMapData):
self._evt = threading.Event() self._evt = threading.Event()
self._threads = [MultiThreadMapData._Worker( self._threads = [MultiThreadMapData._Worker(
self._in_queue, self._out_queue, self._evt, self.map_func) self._in_queue, self._out_queue, self._evt, self.map_func)
for _ in range(self.nr_thread)] for _ in range(self.num_thread)]
for t in self._threads: for t in self._threads:
t.start() t.start()
...@@ -211,10 +217,6 @@ class MultiThreadMapData(_ParallelMapData): ...@@ -211,10 +217,6 @@ class MultiThreadMapData(_ParallelMapData):
# logger.warn("Cannot join thread {}.".format(p.name)) # logger.warn("Cannot join thread {}.".format(p.name))
# TODO deprecated
ThreadedMapData = MultiThreadMapData
class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
""" """
Same as :class:`MapData`, but start processes to run the mapping function, Same as :class:`MapData`, but start processes to run the mapping function,
...@@ -255,16 +257,20 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -255,16 +257,20 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
dp = self.map_func(dp) dp = self.map_func(dp)
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
def __init__(self, ds, nr_proc, map_func, buffer_size=200, strict=False): def __init__(self, ds, num_proc=None, map_func=None, buffer_size=200, strict=False, nr_proc=None):
""" """
Args: Args:
ds (DataFlow): the dataflow to map ds (DataFlow): the dataflow to map
nr_proc(int): number of threads to use num_proc(int): number of threads to use
map_func (callable): datapoint -> datapoint | None. Return None to map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint. discard/skip the datapoint.
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints in the buffer
strict (bool): use "strict mode", see notes above. strict (bool): use "strict mode", see notes above.
nr_proc: deprecated name
""" """
if nr_proc is not None:
log_deprecated("MultiProcessMapDataZMQ(nr_proc)", "Renamed to 'num_proc'", "2020-01-01")
num_proc = nr_proc
if strict: if strict:
# In strict mode, buffer size cannot be larger than the total number of datapoints # In strict mode, buffer size cannot be larger than the total number of datapoints
try: try:
...@@ -274,8 +280,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -274,8 +280,8 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
_ParallelMapData.__init__(self, ds, buffer_size, strict) _ParallelMapData.__init__(self, ds, buffer_size, strict)
_MultiProcessZMQDataFlow.__init__(self) _MultiProcessZMQDataFlow.__init__(self)
assert nr_proc > 0, nr_proc assert num_proc > 0, num_proc
self.nr_proc = nr_proc self.num_proc = num_proc
self.map_func = map_func self.map_func = map_func
self._strict = strict self._strict = strict
self._procs = [] self._procs = []
...@@ -291,11 +297,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -291,11 +297,11 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
pipename = _get_pipe_name('dataflow-map') pipename = _get_pipe_name('dataflow-map')
_bind_guard(self.socket, pipename) _bind_guard(self.socket, pipename)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.nr_proc)] self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = int(self._buffer_size * 2 // self.nr_proc) worker_hwm = int(self._buffer_size * 2 // self.num_proc)
self._procs = [MultiProcessMapDataZMQ._Worker( self._procs = [MultiProcessMapDataZMQ._Worker(
self._proc_ids[k], self.map_func, pipename, worker_hwm) self._proc_ids[k], self.map_func, pipename, worker_hwm)
for k in range(self.nr_proc)] for k in range(self.num_proc)]
self._start_processes() self._start_processes()
self._fill_buffer() # pre-fill the bufer self._fill_buffer() # pre-fill the bufer
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment