Commit cddb713f authored by Yuxin Wu

make send_data and RemoteData useful (#202)

parent 3397e0bd
...@@ -47,6 +47,12 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -47,6 +47,12 @@ class TestDataSpeed(ProxyDataFlow):
if idx == self.test_size - 1: if idx == self.test_size - 1:
break break
def start(self):
"""
Alias of :meth:`start_test`.

Forwards the call to ``self.start_test()`` with no arguments and
returns nothing itself.
"""
self.start_test()
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
""" """
......
...@@ -10,34 +10,33 @@ except ImportError: ...@@ -10,34 +10,33 @@ except ImportError:
logger.warn("Error in 'import zmq'. remote feature won't be available") logger.warn("Error in 'import zmq'. remote feature won't be available")
__all__ = [] __all__ = []
else: else:
__all__ = ['serve_data', 'RemoteData'] __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
from .base import DataFlow from .base import DataFlow
from .common import RepeatedData
from ..utils import logger from ..utils import logger
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
def serve_data(ds, addr): def send_dataflow_zmq(df, addr, hwm=50):
""" """
Serve the DataFlow on a ZMQ socket addr. Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket. It will dump and send each datapoint to this addr with a PUSH socket.
Args: Args:
ds (DataFlow): DataFlow to serve. Will infinitely loop over the DataFlow. df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr. addr: a ZMQ socket addr.
hwm (int): high water mark
""" """
ctx = zmq.Context() ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH) socket = ctx.socket(zmq.PUSH)
socket.set_hwm(10) socket.set_hwm(hwm)
socket.bind(addr) socket.connect(addr)
ds = RepeatedData(ds, -1)
try: try:
ds.reset_state() df.reset_state()
logger.info("Serving data at {}".format(addr)) logger.info("Serving data to {}".format(addr))
# TODO print statistics such as speed # TODO print statistics such as speed
while True: while True:
for dp in ds.get_data(): for dp in df.get_data():
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
finally: finally:
socket.setsockopt(zmq.LINGER, 0) socket.setsockopt(zmq.LINGER, 0)
...@@ -46,39 +45,73 @@ def serve_data(ds, addr): ...@@ -46,39 +45,73 @@ def serve_data(ds, addr):
ctx.destroy(0) ctx.destroy(0)
class RemoteData(DataFlow): class RemoteDataZMQ(DataFlow):
""" Produce data from a ZMQ socket. """ """ Produce data from ZMQ PULL socket(s). """
def __init__(self, addr): def __init__(self, addr1, addr2=None):
""" """
Args: Args:
addr (str): addr of the socket to connect to. addr1,addr2 (str): addr of the socket to connect to.
Use both if you need two protocols (e.g. both IPC and TCP).
I don't think you'll ever need 3.
""" """
self._addr = addr assert addr1
self._addr1 = addr1
self._addr2 = addr2
def get_data(self): def get_data(self):
try: try:
ctx = zmq.Context() ctx = zmq.Context()
if self._addr2 is None:
socket = ctx.socket(zmq.PULL) socket = ctx.socket(zmq.PULL)
socket.connect(self._addr) socket.set_hwm(50)
socket.bind(self._addr1)
while True: while True:
dp = loads(socket.recv(copy=False)) dp = loads(socket.recv(copy=False).bytes)
yield dp
else:
socket1 = ctx.socket(zmq.PULL)
socket1.set_hwm(50)
socket1.bind(self._addr1)
socket2 = ctx.socket(zmq.PULL)
socket2.set_hwm(50)
socket2.bind(self._addr2)
poller = zmq.Poller()
poller.register(socket1, zmq.POLLIN)
poller.register(socket2, zmq.POLLIN)
while True:
evts = poller.poll()
for sock, evt in evts:
dp = loads(sock.recv(copy=False).bytes)
yield dp yield dp
finally: finally:
ctx.destroy(linger=0) ctx.destroy(linger=0)
if __name__ == '__main__': if __name__ == '__main__':
import sys from argparse import ArgumentParser
from tqdm import tqdm
from .raw import FakeData from .raw import FakeData
addr = "tcp://127.0.0.1:8877" from .common import TestDataSpeed
if sys.argv[1] == 'serve':
ds = FakeData([(128, 244, 244, 3)], 1000) """
serve_data(ds, addr) Test the multi-producer single-consumer model
"""
parser = ArgumentParser()
parser.add_argument('-t', '--task', choices=['send', 'recv'], required=True)
parser.add_argument('-a', '--addr1', required=True)
parser.add_argument('-b', '--addr2', default=None)
args = parser.parse_args()
# tcp addr like "tcp://127.0.0.1:8877"
# ipc addr like "ipc:///tmp/ipc-test"
if args.task == 'send':
# use random=True to make it slow and cpu-consuming
ds = FakeData([(128, 244, 244, 3)], 1000, random=True)
send_dataflow_zmq(ds, args.addr1)
else: else:
ds = RemoteData(addr) ds = RemoteDataZMQ(args.addr1, args.addr2)
logger.info("Each DP is 73.5MB") logger.info("Each DP is 73.5MB")
with tqdm(total=10000) as pbar: TestDataSpeed(ds).start_test()
for k in ds.get_data():
pbar.update()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment