Commit 1401d30d authored by Yuxin Wu

mark more DataFlow as non-reentrant

parent 79b9d0eb
@@ -7,7 +7,7 @@ a __Python generator__ which yields preprocessed ImageNet images and labels as f
 Since it is simply a generator interface, you can use the DataFlow in other Python-based frameworks (e.g. Keras)
 or your own code as well.

-**What we are going to do**: We'll use ILSVRC12 training set, which contains 1.28 million images.
+**What we are going to do**: We'll use ILSVRC12 dataset, which contains 1.28 million images.
 The original images (JPEG compressed) are 140G in total.
 The average resolution is about 400x350 <sup>[[1]]</sup>.
 Following the [ResNet example](../examples/ResNet), we need images in their original resolution,

@@ -16,18 +16,18 @@ then apply complicated preprocessing to it.
 We will need to reach a speed of, roughly 1k ~ 2k images per second, to keep GPUs busy.

 Some things to know before reading:
-1. Having a fast Python generator **alone** may or may not help with your overall training speed.
-   You need mechanisms to hide the latency of all preprocessing stages, as mentioned in the
-   [previous tutorial](http://localhost:8000/tutorial/input-source.html).
-2. Requirements on reading training set and validation set are different.
+1. Having a fast Python generator **alone** may or may not improve your overall training speed.
+   You need mechanisms to hide the latency of **all** preprocessing stages, as mentioned in the
+   [previous tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/input-source.html).
+2. Reading training set and validation set are different.
    In training it's OK to reorder, regroup, or even duplicate some datapoints, as long as the
-   distribution roughly stays the same.
-   But in validation we often need the exact set of data, to be able to compute the correct error.
+   data distribution roughly stays the same.
+   But in validation we often need the exact set of data, to be able to compute a correct and comparable score.
    This will affect how we build the DataFlow.
 3. The actual performance would depend on not only the disk, but also memory (for caching) and CPU (for data processing).
    You may need to tune the parameters (#processes, #threads, size of buffer, etc.)
    or change the pipeline for new tasks and new machines to achieve the best performance.
-4. This tutorial could be too complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-sized dataset.
+4. This tutorial could be a bit complicated for people new to system architectures, but you do need these to be able to run fast enough on ImageNet-sized dataset.
    However, for smaller datasets (e.g. several GBs of images with lightweight preprocessing), a simple reader plus some prefetch should work well enough.
    Figure out the bottleneck first, before trying to optimize any piece in the whole system.
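The tutorial text above points at hiding preprocessing latency behind prefetching rather than relying on a fast generator alone. As a rough illustration only (not part of this commit), a pipeline of that kind built from tensorpack's DataFlow primitives could look like the sketch below; the dataset path and the process/batch counts are placeholder assumptions.

```python
# Hypothetical pipeline sketch; the path and the parameter values are placeholders.
from tensorpack.dataflow import dataset, imgaug, AugmentImageComponent, BatchData, PrefetchDataZMQ

ds = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)  # yields [image, label]
ds = AugmentImageComponent(ds, [imgaug.Resize((224, 224))])        # CPU-side preprocessing
ds = PrefetchDataZMQ(ds, nr_proc=25)                               # hide preprocessing latency in multiple processes
ds = BatchData(ds, 256)
ds.reset_state()
for dp in ds.get_data():
    pass  # feed each datapoint to the trainer or to another framework
```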
@@ -9,7 +9,7 @@ from termcolor import colored
 from collections import deque, defaultdict
 from six.moves import range, map

-from .base import DataFlow, ProxyDataFlow, RNGDataFlow
+from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard
 from ..utils import logger
 from ..utils.utils import get_tqdm, get_rng
 from ..utils.develop import log_deprecated

@@ -166,19 +166,21 @@ class BatchDataByShape(BatchData):
         """
         super(BatchDataByShape, self).__init__(ds, batch_size, remainder=False)
         self.idx = idx
+        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
         super(BatchDataByShape, self).reset_state()
         self.holder = defaultdict(list)

     def get_data(self):
-        for dp in self.ds.get_data():
-            shp = dp[self.idx].shape
-            holder = self.holder[shp]
-            holder.append(dp)
-            if len(holder) == self.batch_size:
-                yield BatchData._aggregate_batch(holder)
-                del holder[:]
+        with self._guard:
+            for dp in self.ds.get_data():
+                shp = dp[self.idx].shape
+                holder = self.holder[shp]
+                holder.append(dp)
+                if len(holder) == self.batch_size:
+                    yield BatchData._aggregate_batch(holder)
+                    del holder[:]


 class FixedSizeData(ProxyDataFlow):

@@ -194,25 +196,27 @@ class FixedSizeData(ProxyDataFlow):
         super(FixedSizeData, self).__init__(ds)
         self._size = int(size)
         self.itr = None
+        self._guard = DataFlowReentrantGuard()

     def size(self):
         return self._size

     def get_data(self):
-        if self.itr is None:
-            self.itr = self.ds.get_data()
-        cnt = 0
-        while True:
-            try:
-                dp = next(self.itr)
-            except StopIteration:
-                self.itr = self.ds.get_data()
-                dp = next(self.itr)
-
-            cnt += 1
-            yield dp
-            if cnt == self._size:
-                return
+        with self._guard:
+            if self.itr is None:
+                self.itr = self.ds.get_data()
+            cnt = 0
+            while True:
+                try:
+                    dp = next(self.itr)
+                except StopIteration:
+                    self.itr = self.ds.get_data()
+                    dp = next(self.itr)
+
+                cnt += 1
+                yield dp
+                if cnt == self._size:
+                    return


 class MapData(ProxyDataFlow):

@@ -522,6 +526,7 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
             shuffle_interval = int(buffer_size // 3)
         self.shuffle_interval = shuffle_interval
         self.nr_reuse = nr_reuse
+        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
         ProxyDataFlow.reset_state(self)

@@ -535,23 +540,24 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
             self.q.append(dp)

     def get_data(self):
-        # fill queue
-        while self.q.maxlen > len(self.q):
-            self._add_data()
-
-        sz = self.size()
-        cnt = 0
-        while True:
-            self.rng.shuffle(self.q)
-            for _ in range(self.shuffle_interval):
-                # the inner loop maintains the queue size (almost) unchanged
-                for _ in range(self.nr_reuse):
-                    yield self.q.popleft()
-                cnt += self.nr_reuse
-                if cnt >= sz:
-                    return
-                self._add_data()
+        with self._guard:
+            # fill queue
+            while self.q.maxlen > len(self.q):
+                self._add_data()
+
+            sz = self.size()
+            cnt = 0
+            while True:
+                self.rng.shuffle(self.q)
+                for _ in range(self.shuffle_interval):
+                    # the inner loop maintains the queue size (almost) unchanged
+                    for _ in range(self.nr_reuse):
+                        yield self.q.popleft()
+                    cnt += self.nr_reuse
+                    if cnt >= sz:
+                        return
+                    self._add_data()


 class CacheData(ProxyDataFlow):
     """

@@ -564,6 +570,7 @@ class CacheData(ProxyDataFlow):
             shuffle (bool): whether to shuffle the datapoints before producing them.
         """
         self.shuffle = shuffle
+        self._guard = DataFlowReentrantGuard()
         super(CacheData, self).__init__(ds)

     def reset_state(self):

@@ -573,15 +580,16 @@ class CacheData(ProxyDataFlow):
         self.buffer = []

     def get_data(self):
-        if len(self.buffer):
-            if self.shuffle:
-                self.rng.shuffle(self.buffer)
-            for dp in self.buffer:
-                yield dp
-        else:
-            for dp in self.ds.get_data():
-                yield dp
-                self.buffer.append(dp)
+        with self._guard:
+            if len(self.buffer):
+                if self.shuffle:
+                    self.rng.shuffle(self.buffer)
+                for dp in self.buffer:
+                    yield dp
+            else:
+                for dp in self.ds.get_data():
+                    yield dp
+                    self.buffer.append(dp)


 class PrintData(ProxyDataFlow):
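For context on the `with self._guard:` pattern added above: these DataFlows keep per-iteration state (`self.holder`, `self.itr`, `self.q`, `self.buffer`), so two concurrently active `get_data()` iterators would corrupt each other. A `DataFlowReentrantGuard` is meant to fail loudly in that case. The sketch below is an assumption about its behavior for illustration, not the definition from this commit.

```python
import threading

class ReentrantGuardSketch(object):
    """Hypothetical stand-in for DataFlowReentrantGuard: a non-blocking lock
    held for the lifetime of one get_data() iteration."""
    def __init__(self):
        self._lock = threading.Lock()

    def __enter__(self):
        # acquire without blocking; a second concurrent iteration fails immediately
        if not self._lock.acquire(False):
            raise threading.ThreadError("This DataFlow is not reentrant!")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()
        return False
```

Because `with self._guard:` wraps the whole generator body, the lock is taken when iteration starts and released when the generator is exhausted or closed, which is exactly the window during which a second `get_data()` call must be rejected.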
@@ -13,7 +13,7 @@ from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
 from ..utils.serialize import loads
 from ..utils.argtools import log_once
-from .base import RNGDataFlow, DataFlow
+from .base import RNGDataFlow, DataFlow, DataFlowReentrantGuard
 from .common import MapData

 __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',

@@ -83,6 +83,7 @@ class LMDBData(RNGDataFlow):
         self._size = self._txn.stat()['entries']
         self._set_keys(keys)
         logger.info("Found {} entries in {}".format(self._size, self._lmdb_path))
+        self._guard = DataFlowReentrantGuard()

     def _set_keys(self, keys=None):
         def find_keys(txn, size):

@@ -128,17 +129,18 @@ class LMDBData(RNGDataFlow):
         return self._size

     def get_data(self):
-        if not self._shuffle:
-            c = self._txn.cursor()
-            while c.next():
-                k, v = c.item()
-                if k != b'__keys__':
-                    yield [k, v]
-        else:
-            self.rng.shuffle(self.keys)
-            for k in self.keys:
-                v = self._txn.get(k)
-                yield [k, v]
+        with self._guard:
+            if not self._shuffle:
+                c = self._txn.cursor()
+                while c.next():
+                    k, v = c.item()
+                    if k != b'__keys__':
+                        yield [k, v]
+            else:
+                self.rng.shuffle(self.keys)
+                for k in self.keys:
+                    v = self._txn.get(k)
+                    yield [k, v]


 class LMDBDataDecoder(MapData):

@@ -265,7 +267,7 @@ class TFRecordData(DataFlow):
             size (int): total number of records, because this metadata is not
                 stored in the tfrecord file.
         """
-        self._gen = tf.python_io.tf_record_iterator(path)
+        self._path = path
         self._size = int(size)

     def size(self):

@@ -274,7 +276,8 @@ class TFRecordData(DataFlow):
         return super(TFRecordData, self).size()

     def get_data(self):
-        for dp in self._gen:
+        gen = tf.python_io.tf_record_iterator(self._path)
+        for dp in gen:
             yield loads(dp)

 from ..utils.develop import create_dummy_class  # noqa
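The `TFRecordData` change above moves creation of the record iterator from `__init__` into `get_data()`, so every call to `get_data()` starts reading the file from the beginning instead of resuming a half-consumed iterator. A hedged usage sketch; the tfrecord path and the size are placeholders.

```python
# Hypothetical usage of TFRecordData after this change.
df = TFRecordData('/path/to/train.tfrecord', size=1281167)
df.reset_state()
for _ in range(2):              # the flow can now be iterated more than once
    for dp in df.get_data():    # a fresh tf_record_iterator is created per call
        break
```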
@@ -5,7 +5,7 @@
 import time
 from collections import deque

-from .base import DataFlow
+from .base import DataFlow, DataFlowReentrantGuard
 from ..utils import logger
 from ..utils.utils import get_tqdm
 from ..utils.serialize import dumps, loads, dumps_for_tfop

@@ -72,47 +72,49 @@ class RemoteDataZMQ(DataFlow):
         assert addr1
         self._addr1 = addr1
         self._addr2 = addr2
+        self._guard = DataFlowReentrantGuard()

     def reset_state(self):
         self.cnt1 = 0
         self.cnt2 = 0

     def get_data(self):
-        try:
-            ctx = zmq.Context()
-            if self._addr2 is None:
-                socket = ctx.socket(zmq.PULL)
-                socket.set_hwm(50)
-                socket.bind(self._addr1)
-
-                while True:
-                    dp = loads(socket.recv(copy=False).bytes)
-                    yield dp
-                    self.cnt1 += 1
-            else:
-                socket1 = ctx.socket(zmq.PULL)
-                socket1.set_hwm(50)
-                socket1.bind(self._addr1)
-
-                socket2 = ctx.socket(zmq.PULL)
-                socket2.set_hwm(50)
-                socket2.bind(self._addr2)
-
-                poller = zmq.Poller()
-                poller.register(socket1, zmq.POLLIN)
-                poller.register(socket2, zmq.POLLIN)
-
-                while True:
-                    evts = poller.poll()
-                    for sock, evt in evts:
-                        dp = loads(sock.recv(copy=False).bytes)
-                        yield dp
-                        if sock == socket1:
-                            self.cnt1 += 1
-                        else:
-                            self.cnt2 += 1
-        finally:
-            ctx.destroy(linger=0)
+        with self._guard:
+            try:
+                ctx = zmq.Context()
+                if self._addr2 is None:
+                    socket = ctx.socket(zmq.PULL)
+                    socket.set_hwm(50)
+                    socket.bind(self._addr1)
+
+                    while True:
+                        dp = loads(socket.recv(copy=False).bytes)
+                        yield dp
+                        self.cnt1 += 1
+                else:
+                    socket1 = ctx.socket(zmq.PULL)
+                    socket1.set_hwm(50)
+                    socket1.bind(self._addr1)
+
+                    socket2 = ctx.socket(zmq.PULL)
+                    socket2.set_hwm(50)
+                    socket2.bind(self._addr2)
+
+                    poller = zmq.Poller()
+                    poller.register(socket1, zmq.POLLIN)
+                    poller.register(socket2, zmq.POLLIN)
+
+                    while True:
+                        evts = poller.poll()
+                        for sock, evt in evts:
+                            dp = loads(sock.recv(copy=False).bytes)
+                            yield dp
+                            if sock == socket1:
+                                self.cnt1 += 1
+                            else:
+                                self.cnt2 += 1
+            finally:
+                ctx.destroy(linger=0)


 if __name__ == '__main__':
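`RemoteDataZMQ` above binds one or two ZMQ PULL sockets and yields deserialized datapoints pushed to it from another process, typically via `send_dataflow_zmq` from the same module. A minimal two-process sketch with a placeholder address and a toy sender dataflow:

```python
# Hypothetical setup; the address and the FakeData shapes are placeholders.
# Sender process:
#   from tensorpack.dataflow import FakeData
#   from tensorpack.dataflow.remote import send_dataflow_zmq
#   send_dataflow_zmq(FakeData([[4, 3]], 1000), 'tcp://127.0.0.1:8877')
# Receiver process:
from tensorpack.dataflow.remote import RemoteDataZMQ

df = RemoteDataZMQ('tcp://127.0.0.1:8877')
df.reset_state()
for dp in df.get_data():
    pass  # consume datapoints pulled from the sender
```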