Commit c59586b2 authored by Yuxin Wu's avatar Yuxin Wu

an initial version of prefetch zmq

parent c939e0b3
......@@ -34,7 +34,7 @@ import mock
#+ ', '.join(["{}={}".format(k,v) for k,v in kwargs.items()]) + ')'
MOCK_MODULES = ['numpy', 'scipy', 'tensorflow', 'scipy.misc', 'h5py', 'nltk',
'cv2', 'scipy.io']
'cv2', 'scipy.io', 'dill', 'zmq']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock(name=mod_name)
......
......@@ -46,6 +46,7 @@ NR_DP_TEST = args.number
logger.info("Testing dataflow speed:")
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
for idx, dp in enumerate(config.dataset.get_data()):
del dp
if idx > NR_DP_TEST:
break
pbar.update()
......
......@@ -3,13 +3,19 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import multiprocessing
from threading import Thread
from six.moves import range
from six.moves.queue import Queue
import uuid
import zmq
import os
from .base import ProxyDataFlow
from ..utils.concurrency import ensure_proc_terminate
from ..utils.serialize import *
from ..utils import logger
__all__ = ['PrefetchData']
__all__ = ['PrefetchData', 'PrefetchDataZMQ']
class PrefetchProcess(multiprocessing.Process):
def __init__(self, ds, queue):
......@@ -62,3 +68,65 @@ class PrefetchData(ProxyDataFlow):
x.terminate()
logger.info("Prefetch process exited.")
class PrefetchProcessZMQ(multiprocessing.Process):
def __init__(self, ds, conn_name, qsize=1):
super(PrefetchProcessZMQ, self).__init__()
self.ds = ds
self.qsize = qsize
self.conn_name = conn_name
def run(self):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUSH)
self.socket.set_hwm(self.qsize)
self.socket.connect(self.conn_name)
self.id = os.getpid()
cnt = 0
while True:
for dp in self.ds.get_data():
self.socket.send(dumps(dp))
cnt += 1
print("Proc {} send {}".format(self.id, cnt))
class PrefetchDataZMQ(ProxyDataFlow):
""" Work the same as `PrefetchData`, but faster. """
def __init__(self, ds, nr_prefetch, nr_proc=1):
super(PrefetchDataZMQ, self).__init__(ds)
self.ds = ds
self._size = ds.size()
self.nr_proc = nr_proc
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
name = "ipc://whatever-" + str(uuid.uuid1())[:6]
self.socket.bind(name)
# TODO local queue again? probably don't affect training
self.queue = Queue(maxsize=nr_prefetch)
def enque():
while True:
self.queue.put(loads(self.socket.recv(copy=False)))
self.th = Thread(target=enque)
self.th.daemon = True
self.th.start()
self.procs = [PrefetchProcessZMQ(self.ds, name)
for _ in range(self.nr_proc)]
for x in self.procs:
x.start()
def get_data(self):
for _ in range(self._size):
dp = self.queue.get()
yield dp
#print(self.queue.qsize())
def __del__(self):
logger.info("Prefetch process exiting...")
self.queue.close()
for x in self.procs:
x.terminate()
self.th.terminate()
logger.info("Prefetch process exited.")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
#import dill
__all__ = ['loads', 'dumps']
def dumps(obj):
#return dill.dumps(obj)
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
#return dill.loads(buf)
return msgpack.loads(buf)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment