Commit 1401d30d authored by Yuxin Wu

mark more DataFlow as non-reentrant

parent 79b9d0eb
......@@ -7,7 +7,7 @@ a __Python generator__ which yields preprocessed ImageNet images and labels as f
Since it is simply a generator interface, you can use the DataFlow in other Python-based frameworks (e.g. Keras)
or your own code as well.
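For instance, here is a minimal sketch of consuming a DataFlow from plain Python; the dataset path is a hypothetical placeholder:

```python
from tensorpack.dataflow import dataset

df = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)  # hypothetical path
df.reset_state()                    # must be called once before iterating
for img, label in df.get_data():    # each datapoint is [image array, integer label]
    pass                            # hand the arrays to Keras or your own training loop
```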
**What we are going to do**: We'll use ILSVRC12 training set, which contains 1.28 million images.
**What we are going to do**: We'll use ILSVRC12 dataset, which contains 1.28 million images.
The original images (JPEG compressed) are 140G in total.
The average resolution is about 400x350 <sup>[[1]]</sup>.
Following the [ResNet example](../examples/ResNet), we need images in their original resolution,
......@@ -16,18 +16,18 @@ then apply complicated preprocessing to it.
We will need to reach a speed of roughly 1k ~ 2k images per second to keep GPUs busy.
Some things to know before reading:
1. Having a fast Python generator **alone** may or may not help with your overall training speed.
You need mechanisms to hide the latency of all preprocessing stages, as mentioned in the
[previous tutorial](http://localhost:8000/tutorial/input-source.html).
2. Requirements on reading training set and validation set are different.
1. Having a fast Python generator **alone** may or may not improve your overall training speed.
You need mechanisms to hide the latency of **all** preprocessing stages, as mentioned in the
[previous tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/input-source.html).
2. Reading the training set and reading the validation set have different requirements.
In training it's OK to reorder, regroup, or even duplicate some datapoints, as long as the
distribution roughly stays the same.
But in validation we often need the exact set of data, to be able to compute the correct error.
data distribution roughly stays the same.
But in validation we often need the exact set of data, to be able to compute a correct and comparable score.
This will affect how we build the DataFlow.
3. The actual performance depends not only on the disk, but also on memory (for caching) and CPU (for data processing).
You may need to tune the parameters (#processes, #threads, size of buffer, etc.)
or change the pipeline for new tasks and new machines to achieve the best performance.
4. This tutorial could be too complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-sized dataset.
4. This tutorial could be a bit complicated for people new to system architectures, but you do need these techniques to run fast enough on an ImageNet-sized dataset.
However, for smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some prefetch (see the sketch below) should work well enough.
Figure out the bottleneck first, before trying to optimize any piece in the whole system.
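As a concrete starting point, here is a hedged sketch of the "simple reader plus some prefetch" baseline mentioned above; the augmentor, batch size, and number of processes are assumptions and need tuning for your machine:

```python
# Decode + augment in the reader, then run it in separate processes so
# preprocessing overlaps with training.
from tensorpack.dataflow import dataset, imgaug, AugmentImageComponent, BatchData, PrefetchDataZMQ

df = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)  # hypothetical path
df = AugmentImageComponent(df, [imgaug.Resize((224, 224))])        # placeholder augmentation
df = BatchData(df, 128)
df = PrefetchDataZMQ(df, nr_proc=25)   # number of reader processes: tune for your CPU/disk
df.reset_state()
```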
......
......@@ -9,7 +9,7 @@ from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm, get_rng
from ..utils.develop import log_deprecated
......@@ -166,12 +166,14 @@ class BatchDataByShape(BatchData):
"""
super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
self.idx = idx
self._guard = DataFlowReentrantGuard()
def reset_state(self):
super(BatchDataByShape, self).reset_state()
self.holder = defaultdict(list)
def get_data(self):
with self._guard:
for dp in self.ds.get_data():
shp = dp[self.idx].shape
holder = self.holder[shp]
......@@ -194,11 +196,13 @@ class FixedSizeData(ProxyDataFlow):
super(FixedSizeData, self).__init__(ds)
self._size = int(size)
self.itr = None
self._guard = DataFlowReentrantGuard()
def size(self):
return self._size
def get_data(self):
with self._guard:
if self.itr is None:
self.itr = self.ds.get_data()
cnt = 0
......@@ -522,6 +526,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
shuffle_interval = int(buffer_size // 3)
self.shuffle_interval = shuffle_interval
self.nr_reuse = nr_reuse
self._guard = DataFlowReentrantGuard()
def reset_state(self):
ProxyDataFlow.reset_state(self)
......@@ -535,6 +540,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
self.q.append(dp)
def get_data(self):
with self._guard:
# fill queue
while self.q.maxlen > len(self.q):
self._add_data()
......@@ -564,6 +570,7 @@ class CacheData(ProxyDataFlow):
shuffle (bool): whether to shuffle the datapoints before producing them.
"""
self.shuffle = shuffle
self._guard = DataFlowReentrantGuard()
super(CacheData, self).__init__(ds)
def reset_state(self):
......@@ -573,6 +580,7 @@ class CacheData(ProxyDataFlow):
self.buffer = []
def get_data(self):
with self._guard:
if len(self.buffer):
if self.shuffle:
self.rng.shuffle(self.buffer)
......
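The pattern applied above in `common.py` can be reused for custom DataFlows that keep per-iteration state. A minimal sketch (the class name is hypothetical; the guard's behavior is assumed from its name and its usage in this commit):

```python
from tensorpack.dataflow.base import ProxyDataFlow, DataFlowReentrantGuard

class MyNonReentrantFlow(ProxyDataFlow):
    """Hypothetical DataFlow whose iteration keeps internal state, so two
    concurrent iterators would corrupt each other."""
    def __init__(self, ds):
        super(MyNonReentrantFlow, self).__init__(ds)
        self._guard = DataFlowReentrantGuard()

    def get_data(self):
        # The guard is assumed to raise if get_data() is entered again
        # before the previous iteration has finished.
        with self._guard:
            for dp in self.ds.get_data():
                yield dp
```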
......@@ -13,7 +13,7 @@ from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads
from ..utils.argtools import log_once
from .base import RNGDataFlow, DataFlow
from .base import RNGDataFlow, DataFlow, DataFlowReentrantGuard
from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
......@@ -83,6 +83,7 @@ class LMDBData(RNGDataFlow):
self._size = self._txn.stat()['entries']
self._set_keys(keys)
logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
self._guard = DataFlowReentrantGuard()
def _set_keys(self, keys=None):
def find_keys(txn, size):
......@@ -128,6 +129,7 @@ class LMDBData(RNGDataFlow):
return self._size
def get_data(self):
with self._guard:
if not self._shuffle:
c = self._txn.cursor()
while c.next():
......@@ -265,7 +267,7 @@ class TFRecordData(DataFlow):
size (int): total number of records, because this metadata is not
stored in the tfrecord file.
"""
self._gen = tf.python_io.tf_record_iterator(path)
self._path = path
self._size = int(size)
def size(self):
......@@ -274,7 +276,8 @@ class TFRecordData(DataFlow):
return super(TFRecordData, self).size()
def get_data(self):
for dp in self._gen:
gen = tf.python_io.tf_record_iterator(self._path)
for dp in gen:
yield loads(dp)
from ..utils.develop import create_dummy_class # noqa
......
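The `TFRecordData` change above moves iterator creation from `__init__` into `get_data()`. A standalone toy sketch (not tensorpack code) of why that matters: an iterator created once is exhausted after the first epoch, while creating it per call lets every epoch see the full data:

```python
def make_records():
    return iter([b'rec0', b'rec1', b'rec2'])  # stand-in for tf_record_iterator

class EagerReader:
    def __init__(self):
        self._gen = make_records()   # created once; empty after the first pass

    def get_data(self):
        for dp in self._gen:
            yield dp

class LazyReader:
    def get_data(self):
        for dp in make_records():    # fresh iterator on every call
            yield dp

eager, lazy = EagerReader(), LazyReader()
print(len(list(eager.get_data())), len(list(eager.get_data())))  # 3 0
print(len(list(lazy.get_data())), len(list(lazy.get_data())))    # 3 3
```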
......@@ -5,7 +5,7 @@
import time
from collections import deque
from .base import DataFlow
from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.serialize import dumps, loads, dumps_for_tfop
......@@ -72,12 +72,14 @@ class RemoteDataZMQ(DataFlow):
assert addr1
self._addr1 = addr1
self._addr2 = addr2
self._guard = DataFlowReentrantGuard()
def reset_state(self):
self.cnt1 = 0
self.cnt2 = 0
def get_data(self):
with self._guard:
try:
ctx = zmq.Context()
if self._addr2 is None:
......
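A hedged usage sketch for `RemoteDataZMQ` (the address and the sender helper are assumptions based on this module): the receiving DataFlow owns a single ZMQ socket, so its `get_data()` should only be iterated by one consumer at a time, which is what the new guard enforces:

```python
from tensorpack.dataflow import RemoteDataZMQ

# On the producing machine (assumed helper in tensorpack.dataflow.remote):
#   send_dataflow_zmq(my_dataflow, 'tcp://consumer-host:8877')

# On the consuming machine:
df = RemoteDataZMQ('tcp://0.0.0.0:8877')   # hypothetical address to pull from
df.reset_state()
for dp in df.get_data():
    pass  # datapoints arrive deserialized from the network
```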