Commit 8bdd9c85 authored by Yuxin Wu's avatar Yuxin Wu

speed up prefetch

parent c59586b2
...@@ -97,7 +97,7 @@ def get_data(train_or_test): ...@@ -97,7 +97,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain) ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain: if isTrain:
ds = PrefetchData(ds, 10, 5) ds = PrefetchDataZMQ(ds, 5)
return ds return ds
def get_config(): def get_config():
......
...@@ -145,7 +145,8 @@ class FakeData(DataFlow): ...@@ -145,7 +145,8 @@ class FakeData(DataFlow):
def get_data(self): def get_data(self):
for _ in range(self._size): for _ in range(self._size):
yield [self.rng.random_sample(k) for k in self.shapes] yield [self.rng.random_sample(k).astype('float32') for k in self.shapes]
#yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint""" """ Apply map/filter a function on the datapoint"""
......
...@@ -11,7 +11,7 @@ from six.moves import range ...@@ -11,7 +11,7 @@ from six.moves import range
try: try:
import h5py import h5py
except ImportError: except ImportError:
logger.error("Error in 'import h5py'. HDF5Data won't be imported.") logger.error("Error in 'import h5py'. HDF5Data won't be available.")
__all__ = [] __all__ = []
else: else:
__all__ = ['HDF5Data'] __all__ = ['HDF5Data']
......
...@@ -7,7 +7,6 @@ from threading import Thread ...@@ -7,7 +7,6 @@ from threading import Thread
from six.moves import range from six.moves import range
from six.moves.queue import Queue from six.moves.queue import Queue
import uuid import uuid
import zmq
import os import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
...@@ -15,7 +14,14 @@ from ..utils.concurrency import ensure_proc_terminate ...@@ -15,7 +14,14 @@ from ..utils.concurrency import ensure_proc_terminate
from ..utils.serialize import * from ..utils.serialize import *
from ..utils import logger from ..utils import logger
__all__ = ['PrefetchData', 'PrefetchDataZMQ'] try:
import zmq
except ImportError:
logger.error("Error in 'import zmq'. PrefetchDataZMQ won't be available.")
__all__ = ['PrefetchData']
else:
__all__ = ['PrefetchData', 'PrefetchDataZMQ']
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue):
...@@ -69,64 +75,57 @@ class PrefetchData(ProxyDataFlow): ...@@ -69,64 +75,57 @@ class PrefetchData(ProxyDataFlow):
logger.info("Prefetch process exited.") logger.info("Prefetch process exited.")
class PrefetchProcessZMQ(multiprocessing.Process): class PrefetchProcessZMQ(multiprocessing.Process):
def __init__(self, ds, conn_name, qsize=1): def __init__(self, ds, conn_name):
"""
:param ds: a `DataFlow` instance.
:param conn_name: the name of the IPC connection
"""
super(PrefetchProcessZMQ, self).__init__() super(PrefetchProcessZMQ, self).__init__()
self.ds = ds self.ds = ds
self.qsize = qsize
self.conn_name = conn_name self.conn_name = conn_name
def run(self): def run(self):
self.ds.reset_state()
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH) self.socket = self.context.socket(zmq.PUSH)
self.socket.set_hwm(self.qsize) self.socket.set_hwm(1)
self.socket.connect(self.conn_name) self.socket.connect(self.conn_name)
self.id = os.getpid()
cnt = 0
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.socket.send(dumps(dp)) self.socket.send(dumps(dp), copy=False)
cnt += 1
print("Proc {} send {}".format(self.id, cnt))
class PrefetchDataZMQ(ProxyDataFlow): class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """ """ Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_prefetch, nr_proc=1): def __init__(self, ds, nr_proc=1):
"""
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random.
"""
super(PrefetchDataZMQ, self).__init__(ds) super(PrefetchDataZMQ, self).__init__(ds)
self.ds = ds
self._size = ds.size() self._size = ds.size()
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL) self.socket = self.context.socket(zmq.PULL)
name = "ipc://whatever-" + str(uuid.uuid1())[:6] self.pipename = "ipc://dataflow-pipe-" + str(uuid.uuid1())[:6]
self.socket.bind(name) self.socket.set_hwm(5) # a little bit faster than default, don't know why
self.socket.bind(self.pipename)
# TODO local queue again? probably don't affect training
self.queue = Queue(maxsize=nr_prefetch) self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
def enque():
while True:
self.queue.put(loads(self.socket.recv(copy=False)))
self.th = Thread(target=enque)
self.th.daemon = True
self.th.start()
self.procs = [PrefetchProcessZMQ(self.ds, name)
for _ in range(self.nr_proc)] for _ in range(self.nr_proc)]
for x in self.procs: for x in self.procs:
x.start() x.start()
def get_data(self): def get_data(self):
for _ in range(self._size): for _ in range(self._size):
dp = self.queue.get() dp = loads(self.socket.recv(copy=False))
yield dp yield dp
#print(self.queue.qsize())
def __del__(self): def __del__(self):
logger.info("Prefetch process exiting...") logger.info("Prefetch process exiting...")
self.queue.close() self.context.destroy(0)
for x in self.procs: for x in self.procs:
x.terminate() x.terminate()
self.th.terminate()
logger.info("Prefetch process exited.") logger.info("Prefetch process exited.")
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
# File: serialize.py # File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import msgpack #import msgpack
import msgpack_numpy #import msgpack_numpy
msgpack_numpy.patch() #msgpack_numpy.patch()
#import dill import dill
__all__ = ['loads', 'dumps'] __all__ = ['loads', 'dumps']
def dumps(obj): def dumps(obj):
#return dill.dumps(obj) return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True) #return msgpack.dumps(obj, use_bin_type=True)
def loads(buf): def loads(buf):
#return dill.loads(buf) return dill.loads(buf)
return msgpack.loads(buf) #return msgpack.loads(buf)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment