Commit 8bdd9c85 authored by Yuxin Wu's avatar Yuxin Wu

speed up prefetch

parent c59586b2
......@@ -97,7 +97,7 @@ def get_data(train_or_test):
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 10, 5)
ds = PrefetchDataZMQ(ds, 5)
return ds
def get_config():
......
......@@ -145,7 +145,8 @@ class FakeData(DataFlow):
def get_data(self):
    """Yield `self._size` random datapoints.

    Each datapoint is a list of float32 arrays, one per entry of
    `self.shapes`, sampled from `self.rng`.
    """
    for _ in range(self._size):
        # cast to float32 so consumers get the dtype they expect
        # instead of numpy's default float64
        yield [self.rng.random_sample(k).astype('float32') for k in self.shapes]
class MapData(ProxyDataFlow):
""" Apply map/filter a function on the datapoint"""
......
......@@ -11,7 +11,7 @@ from six.moves import range
# h5py is optional: if it is missing, HDF5Data simply isn't exported.
try:
    import h5py
except ImportError:
    logger.error("Error in 'import h5py'. HDF5Data won't be available.")
    __all__ = []
else:
    __all__ = ['HDF5Data']
......
......@@ -7,7 +7,6 @@ from threading import Thread
from six.moves import range
from six.moves.queue import Queue
import uuid
import zmq
import os
from .base import ProxyDataFlow
......@@ -15,7 +14,14 @@ from ..utils.concurrency import ensure_proc_terminate
from ..utils.serialize import *
from ..utils import logger
# zmq is optional: without it, only the multiprocessing-based
# PrefetchData is exported; PrefetchDataZMQ needs pyzmq.
try:
    import zmq
except ImportError:
    logger.error("Error in 'import zmq'. PrefetchDataZMQ won't be available.")
    __all__ = ['PrefetchData']
else:
    __all__ = ['PrefetchData', 'PrefetchDataZMQ']
class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue):
......@@ -69,64 +75,57 @@ class PrefetchData(ProxyDataFlow):
logger.info("Prefetch process exited.")
class PrefetchProcessZMQ(multiprocessing.Process):
    """Worker process that runs a DataFlow and pushes its datapoints
    through a ZMQ PUSH socket to the owning `PrefetchDataZMQ`."""

    def __init__(self, ds, conn_name):
        """
        :param ds: a `DataFlow` instance.
        :param conn_name: the name of the IPC connection
        """
        super(PrefetchProcessZMQ, self).__init__()
        self.ds = ds
        self.conn_name = conn_name

    def run(self):
        # reset_state must happen inside the child process so each worker
        # gets its own RNG/resources after fork
        self.ds.reset_state()
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUSH)
        # hwm=1: block this producer as soon as one message is queued,
        # so workers don't run far ahead of the consumer and waste memory
        self.socket.set_hwm(1)
        self.socket.connect(self.conn_name)
        while True:
            for dp in self.ds.get_data():
                # copy=False avoids an extra copy of the serialized datapoint
                self.socket.send(dumps(dp), copy=False)
class PrefetchDataZMQ(ProxyDataFlow):
    """ Work the same as `PrefetchData`, but faster. """

    def __init__(self, ds, nr_proc=1):
        """
        :param ds: a `DataFlow` instance.
        :param nr_proc: number of processes to use. When larger than 1, order
            of datapoints will be random.
        """
        super(PrefetchDataZMQ, self).__init__(ds)
        self.ds = ds
        self._size = ds.size()
        self.nr_proc = nr_proc
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PULL)
        # random suffix so multiple instances don't collide on the same pipe
        self.pipename = "ipc://dataflow-pipe-" + str(uuid.uuid1())[:6]
        self.socket.set_hwm(5)  # a little bit faster than default, don't know why
        self.socket.bind(self.pipename)

        self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
                      for _ in range(self.nr_proc)]
        for x in self.procs:
            x.start()

    def get_data(self):
        # pull serialized datapoints straight off the socket; no local
        # queue/thread needed — ZMQ already buffers between processes
        for _ in range(self._size):
            dp = loads(self.socket.recv(copy=False))
            yield dp

    def __del__(self):
        logger.info("Prefetch process exiting...")
        self.context.destroy(0)
        for x in self.procs:
            x.terminate()
        logger.info("Prefetch process exited.")
......@@ -3,17 +3,17 @@
# File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# dill can serialize almost anything (lambdas, closures, numpy arrays)
# without the per-type handling that msgpack would require.
# NOTE(review): dill, like pickle, must never be used on untrusted input.
import dill

__all__ = ['loads', 'dumps']


def dumps(obj):
    """Serialize `obj` to a bytes object using dill."""
    return dill.dumps(obj)


def loads(buf):
    """Deserialize a bytes object produced by `dumps`."""
    return dill.loads(buf)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment