Commit c59586b2 authored by Yuxin Wu's avatar Yuxin Wu

an initial version of prefetch zmq

parent c939e0b3
...@@ -34,7 +34,7 @@ import mock ...@@ -34,7 +34,7 @@ import mock
#+ ', '.join(["{}={}".format(k,v) for k,v in kwargs.items()]) + ')' #+ ', '.join(["{}={}".format(k,v) for k,v in kwargs.items()]) + ')'
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
                'cv2', 'scipy.io', 'dill', 'zmq']
for mod_name in MOCK_MODULES:
    sys.modules[mod_name] = mock.Mock(name=mod_name)
......
...@@ -4,3 +4,4 @@ scipy ...@@ -4,3 +4,4 @@ scipy
tqdm tqdm
h5py h5py
nltk nltk
dill
...@@ -46,6 +46,7 @@ NR_DP_TEST = args.number ...@@ -46,6 +46,7 @@ NR_DP_TEST = args.number
logger.info("Testing dataflow speed:") logger.info("Testing dataflow speed:")
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar: with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
for idx, dp in enumerate(config.dataset.get_data()): for idx, dp in enumerate(config.dataset.get_data()):
del dp
if idx > NR_DP_TEST: if idx > NR_DP_TEST:
break break
pbar.update() pbar.update()
......
...@@ -3,13 +3,19 @@ ...@@ -3,13 +3,19 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import multiprocessing import multiprocessing
from threading import Thread
from six.moves import range from six.moves import range
from six.moves.queue import Queue
import uuid
import zmq
import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import ensure_proc_terminate from ..utils.concurrency import ensure_proc_terminate
from ..utils.serialize import *
from ..utils import logger from ..utils import logger
__all__ = ['PrefetchData', 'PrefetchDataZMQ']
class PrefetchProcess(multiprocessing.Process): class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue): def __init__(self, ds, queue):
...@@ -62,3 +68,65 @@ class PrefetchData(ProxyDataFlow): ...@@ -62,3 +68,65 @@ class PrefetchData(ProxyDataFlow):
x.terminate() x.terminate()
logger.info("Prefetch process exited.") logger.info("Prefetch process exited.")
class PrefetchProcessZMQ(multiprocessing.Process):
    """A worker process which runs a DataFlow forever and pushes each
    serialized datapoint through a ZMQ PUSH socket."""

    def __init__(self, ds, conn_name, qsize=1):
        """
        :param ds: the DataFlow to iterate over.
        :param conn_name: the ZMQ endpoint (e.g. an ``ipc://`` address) to connect to.
        :param qsize: high-water mark (send-side buffer size) of the PUSH socket.
        """
        super(PrefetchProcessZMQ, self).__init__()
        self.ds = ds
        self.qsize = qsize
        self.conn_name = conn_name

    def run(self):
        # The ZMQ context and socket must be created here, in the child
        # process, because ZMQ objects cannot be shared across a fork.
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUSH)
        self.socket.set_hwm(self.qsize)
        self.socket.connect(self.conn_name)
        self.id = os.getpid()
        cnt = 0
        while True:
            # Repeat the dataset forever; the consumer side decides how
            # many datapoints it actually takes.
            for dp in self.ds.get_data():
                self.socket.send(dumps(dp))
                cnt += 1
            # was a bare print(); use the module logger like the rest of the file
            logger.info("Proc {} send {}".format(self.id, cnt))
class PrefetchDataZMQ(ProxyDataFlow):
    """ Work the same as `PrefetchData`, but faster. """

    def __init__(self, ds, nr_prefetch, nr_proc=1):
        """
        :param ds: the DataFlow to prefetch from.
        :param nr_prefetch: size of the local receive queue.
        :param nr_proc: number of worker processes pushing datapoints.
        """
        super(PrefetchDataZMQ, self).__init__(ds)
        self.ds = ds
        self._size = ds.size()
        self.nr_proc = nr_proc

        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PULL)
        # a unique ipc endpoint, so that several instances don't collide
        name = "ipc://whatever-" + str(uuid.uuid1())[:6]
        self.socket.bind(name)

        # TODO local queue again? probably don't affect training
        self.queue = Queue(maxsize=nr_prefetch)

        def enque():
            # move datapoints from the PULL socket into the local queue
            while True:
                self.queue.put(loads(self.socket.recv(copy=False)))
        self.th = Thread(target=enque)
        self.th.daemon = True   # dies together with the main process
        self.th.start()

        self.procs = [PrefetchProcessZMQ(self.ds, name)
                      for _ in range(self.nr_proc)]
        for x in self.procs:
            x.start()

    def get_data(self):
        """Yield exactly ``size()`` datapoints taken from the local queue."""
        for _ in range(self._size):
            yield self.queue.get()

    def __del__(self):
        logger.info("Prefetch process exiting...")
        # Only the worker processes need explicit termination:
        # - queue.Queue has no close() (only multiprocessing queues do)
        # - threading.Thread has no terminate(); the daemon thread dies
        #   with the interpreter anyway
        for x in self.procs:
            x.terminate()
        # close sockets without lingering on unsent messages
        self.context.destroy(0)
        logger.info("Prefetch process exited.")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
#import dill
__all__ = ['loads', 'dumps']
def dumps(obj):
    """
    Serialize an object to ``bytes`` with msgpack.

    ``use_bin_type=True`` keeps the str/bytes distinction in the wire
    format, so unicode strings and binary blobs can round-trip separately.
    """
    return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
    """
    Deserialize bytes produced by :func:`dumps`.

    ``raw=False`` decodes msgpack str objects back to unicode strings,
    matching the ``use_bin_type=True`` used in ``dumps`` — without it,
    every string comes back as ``bytes`` on Python 3.
    """
    return msgpack.loads(buf, raw=False)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment