Commit 6afeb544 authored by Yuxin Wu

[ZMQ] better toolchain for zmq send/recv data

parent 4c02f009
...@@ -86,3 +86,4 @@ target/ ...@@ -86,3 +86,4 @@ target/
*.dat *.dat
.idea/ .idea/
*.diff
...@@ -4,10 +4,12 @@ ...@@ -4,10 +4,12 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import time import time
import tqdm
from collections import deque from collections import deque
from .base import DataFlow, DataFlowReentrantGuard from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm from ..utils.utils import get_tqdm_kwargs
from ..utils.serialize import dumps, loads from ..utils.serialize import dumps, loads
try: try:
import zmq import zmq
...@@ -18,7 +20,7 @@ else: ...@@ -18,7 +20,7 @@ else:
__all__ = ['send_dataflow_zmq', 'RemoteDataZMQ'] __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None): def send_dataflow_zmq(df, addr, hwm=50, format=None):
""" """
Run DataFlow and send data to a ZMQ socket addr. Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr, It will __connect__ to this addr,
...@@ -47,16 +49,25 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None): ...@@ -47,16 +49,25 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
try: try:
df.reset_state() df.reset_state()
logger.info("Serving data to {} ...".format(addr)) logger.info("Serving data to {} ...".format(addr))
q = deque(maxlen=print_interval) INTERVAL = 200
with get_tqdm(total=0) as pbar: q = deque(maxlen=INTERVAL)
while True:
try:
total = df.size()
except NotImplementedError:
total = 0
tqdm_args = get_tqdm_kwargs(leave=True)
tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
while True:
with tqdm.trange(total, **tqdm_args) as pbar:
for dp in df.get_data(): for dp in df.get_data():
start = time.time() start = time.time()
socket.send(dump_fn(dp), copy=False) socket.send(dump_fn(dp), copy=False)
q.append(time.time() - start) q.append(time.time() - start)
pbar.update(1) pbar.update(1)
if pbar.n % print_interval == 0: if pbar.n % INTERVAL == 0:
pbar.write("Avg send time @{}: {}".format(pbar.n, sum(q) / len(q))) avg = "{:.3f}".format(sum(q) / len(q))
pbar.set_postfix({'AvgSendLat': avg})
finally: finally:
socket.setsockopt(zmq.LINGER, 0) socket.setsockopt(zmq.LINGER, 0)
socket.close() socket.close()
......
...@@ -34,6 +34,7 @@ class InputDesc( ...@@ -34,6 +34,7 @@ class InputDesc(
name (str): name (str):
""" """
shape = tuple(shape) # has to be tuple for "self" to be hashable shape = tuple(shape) # has to be tuple for "self" to be hashable
assert isinstance(type, tf.DType), type
self = super(InputDesc, cls).__new__(cls, type, shape, name) self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None self._cached_placeholder = None
return self return self
......
...@@ -24,13 +24,11 @@ from ..utils.develop import log_deprecated ...@@ -24,13 +24,11 @@ from ..utils.develop import log_deprecated
from ..callbacks.base import Callback, CallbackFactory from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp from ..callbacks.graph import RunOp
__all__ = ['PlaceholderInput', 'FeedInput', __all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput', 'DummyConstantInput', 'TensorInput',
'TFDatasetInput', 'ZMQInput', 'TFDatasetInput',
'StagingInputWrapper', 'StagingInputWrapper', 'StagingInput']
'StagingInput']
def _get_reset_callback(df): def _get_reset_callback(df):
...@@ -382,29 +380,36 @@ class DummyConstantInput(TensorInput): ...@@ -382,29 +380,36 @@ class DummyConstantInput(TensorInput):
class ZMQInput(TensorInput): class ZMQInput(TensorInput):
""" """
Not well implemented yet. Don't use. Recv tensors from a ZMQ endpoint.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
""" """
def __init__(self, endpoint, hwm): def __init__(self, end_point, hwm):
self._endpoint = endpoint """
Args:
from tensorpack.user_ops import zmq_recv end_point (str):
hwm (int):
"""
self._end_point = end_point
self._hwm = int(hwm)
def fn(): def fn():
ret = zmq_recv( ret = self._zmq_recv_socket.recv()
self._endpoint, [x.dtype for x in self.inputs_desc], assert len(ret) == len(self._desc)
hwm=hwm) for qv, v in zip(ret, self._desc):
if isinstance(ret, tf.Tensor):
ret = [ret]
assert len(ret) == len(self.inputs_desc)
for qv, v in zip(ret, self.inputs_desc):
qv.set_shape(v.shape) qv.set_shape(v.shape)
return ret return ret
super(ZMQInput, self).__init__(fn) super(ZMQInput, self).__init__(fn)
def _setup(self, inputs_desc): def _setup(self, inputs_desc):
self.inputs_desc = inputs_desc assert len(inputs_desc) > 0, \
assert len(self.inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!" "ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc
from ..user_ops import zmq_recv
self._zmq_recv_socket = zmq_recv.ZMQSocket(
self._end_point,
[x.type for x in inputs_desc],
self._hwm)
class TFDatasetInput(FeedfreeInput): class TFDatasetInput(FeedfreeInput):
......
# $File: Makefile # $File: Makefile
# $Date: Tue Dec 12 22:27:38 2017 -0800 # $Date: Thu Dec 14 18:03:41 2017 -0800
OBJ_DIR = obj OBJ_DIR = obj
PYTHON = python PYTHON = python
...@@ -16,25 +16,24 @@ OPTFLAGS ?= -O3 -march=native ...@@ -16,25 +16,24 @@ OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan #OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan #OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# libraries: TF precedes others, so g++ looks for protobuf among TF headers
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)
# extra packages from pkg-config # extra packages from pkg-config
LIBS = libzmq LIBS = libzmq
INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS)) CXXFLAGS += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs) LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(OPTFLAGS) LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC LDFLAGS += -shared -fPIC
LDFLAGS += $(TF_LDFLAGS)
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif endif
......
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa from tensorpack.user_ops.zmq_recv import ( # noqa
ZMQRecv, dumps_zmq_op) ZMQSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal, start_proc_mask_signal,
ensure_proc_terminate) ensure_proc_terminate)
...@@ -68,7 +68,7 @@ if __name__ == '__main__': ...@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
recv = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8]).recv() recv = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8]).recv()
print(recv) print(recv)
for truth in DATA: for truth in DATA:
...@@ -87,7 +87,7 @@ if __name__ == '__main__': ...@@ -87,7 +87,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
zmqsock = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8], hwm=1) zmqsock = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.recv() recv1 = zmqsock.recv()
recv2 = zmqsock.recv() recv2 = zmqsock.recv()
print(recv1, recv2) print(recv1, recv2)
......
...@@ -13,8 +13,7 @@ from tensorflow.core.framework import types_pb2 as DT ...@@ -13,8 +13,7 @@ from tensorflow.core.framework import types_pb2 as DT
from .common import compile, get_ext_suffix from .common import compile, get_ext_suffix
__all__ = ['dumps_zmq_op', 'ZMQRecv', __all__ = ['dumps_zmq_op', 'ZMQSocket']
'dump_tensor_protos', 'to_tensor_proto']
_zmq_recv_mod = None _zmq_recv_mod = None
...@@ -36,9 +35,10 @@ def try_build(): ...@@ -36,9 +35,10 @@ def try_build():
try_build() try_build()
class ZMQRecv(object): class ZMQSocket(object):
def __init__(self, end_point, types, hwm=None, name=None): def __init__(self, end_point, types, hwm=None, bind=True, name=None):
self._types = types self._types = types
assert isinstance(bind, bool), bind
if name is None: if name is None:
self._name = (tf.get_default_graph() self._name = (tf.get_default_graph()
...@@ -47,7 +47,7 @@ class ZMQRecv(object): ...@@ -47,7 +47,7 @@ class ZMQRecv(object):
self._name = name self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection( self._zmq_handle = _zmq_recv_mod.zmq_connection(
end_point, hwm, shared_name=self._name) end_point, hwm, bind=bind, shared_name=self._name)
@property @property
def name(self): def name(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment