Commit 6afeb544 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] better toolchain for zmq send/recv data

parent 4c02f009
......@@ -86,3 +86,4 @@ target/
*.dat
.idea/
*.diff
......@@ -4,10 +4,12 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import time
import tqdm
from collections import deque
from .base import DataFlow, DataFlowReentrantGuard
from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.utils import get_tqdm_kwargs
from ..utils.serialize import dumps, loads
try:
import zmq
......@@ -18,7 +20,7 @@ else:
__all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
def send_dataflow_zmq(df, addr, hwm=50, format=None):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr,
......@@ -47,16 +49,25 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
try:
df.reset_state()
logger.info("Serving data to {} ...".format(addr))
q = deque(maxlen=print_interval)
with get_tqdm(total=0) as pbar:
INTERVAL = 200
q = deque(maxlen=INTERVAL)
try:
total = df.size()
except NotImplementedError:
total = 0
tqdm_args = get_tqdm_kwargs(leave=True)
tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
while True:
with tqdm.trange(total, **tqdm_args) as pbar:
for dp in df.get_data():
start = time.time()
socket.send(dump_fn(dp), copy=False)
q.append(time.time() - start)
pbar.update(1)
if pbar.n % print_interval == 0:
pbar.write("Avg send time @{}: {}".format(pbar.n, sum(q) / len(q)))
if pbar.n % INTERVAL == 0:
avg = "{:.3f}".format(sum(q) / len(q))
pbar.set_postfix({'AvgSendLat': avg})
finally:
socket.setsockopt(zmq.LINGER, 0)
socket.close()
......
......@@ -34,6 +34,7 @@ class InputDesc(
name (str):
"""
shape = tuple(shape) # has to be tuple for "self" to be hashable
assert isinstance(type, tf.DType), type
self = super(InputDesc, cls).__new__(cls, type, shape, name)
self._cached_placeholder = None
return self
......
......@@ -24,13 +24,11 @@ from ..utils.develop import log_deprecated
from ..callbacks.base import Callback, CallbackFactory
from ..callbacks.graph import RunOp
__all__ = ['PlaceholderInput', 'FeedInput',
'FeedfreeInput',
__all__ = ['PlaceholderInput', 'FeedInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput',
'TFDatasetInput',
'StagingInputWrapper',
'StagingInput']
'ZMQInput', 'TFDatasetInput',
'StagingInputWrapper', 'StagingInput']
def _get_reset_callback(df):
......@@ -382,29 +380,36 @@ class DummyConstantInput(TensorInput):
class ZMQInput(TensorInput):
"""
Not well implemented yet. Don't use.
Recv tensors from a ZMQ endpoint.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
"""
def __init__(self, endpoint, hwm):
self._endpoint = endpoint
from tensorpack.user_ops import zmq_recv
def __init__(self, end_point, hwm):
"""
Args:
end_point (str):
hwm (int):
"""
self._end_point = end_point
self._hwm = int(hwm)
def fn():
ret = zmq_recv(
self._endpoint, [x.dtype for x in self.inputs_desc],
hwm=hwm)
if isinstance(ret, tf.Tensor):
ret = [ret]
assert len(ret) == len(self.inputs_desc)
for qv, v in zip(ret, self.inputs_desc):
ret = self._zmq_recv_socket.recv()
assert len(ret) == len(self._desc)
for qv, v in zip(ret, self._desc):
qv.set_shape(v.shape)
return ret
super(ZMQInput, self).__init__(fn)
def _setup(self, inputs_desc):
self.inputs_desc = inputs_desc
assert len(self.inputs_desc) > 0, \
assert len(inputs_desc) > 0, \
"ZMQInput has to be used with InputDesc!"
self._desc = inputs_desc
from ..user_ops import zmq_recv
self._zmq_recv_socket = zmq_recv.ZMQSocket(
self._end_point,
[x.type for x in inputs_desc],
self._hwm)
class TFDatasetInput(FeedfreeInput):
......
# $File: Makefile
# $Date: Tue Dec 12 22:27:38 2017 -0800
# $Date: Thu Dec 14 18:03:41 2017 -0800
OBJ_DIR = obj
PYTHON = python
......@@ -16,25 +16,24 @@ OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# libraries: TF preceeds others, so g++ looks for protobuf among TF headers
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)
# extra packages from pkg-config
LIBS = libzmq
INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
CXXFLAGS += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC
LDFLAGS += $(TF_LDFLAGS)
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif
......
......@@ -12,7 +12,7 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa
ZMQRecv, dumps_zmq_op)
ZMQSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
......@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
recv = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8]).recv()
recv = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8]).recv()
print(recv)
for truth in DATA:
......@@ -87,7 +87,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
zmqsock = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
zmqsock = ZMQSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.recv()
recv2 = zmqsock.recv()
print(recv1, recv2)
......
......@@ -13,8 +13,7 @@ from tensorflow.core.framework import types_pb2 as DT
from .common import compile, get_ext_suffix
__all__ = ['dumps_zmq_op', 'ZMQRecv',
'dump_tensor_protos', 'to_tensor_proto']
__all__ = ['dumps_zmq_op', 'ZMQSocket']
_zmq_recv_mod = None
......@@ -36,9 +35,10 @@ def try_build():
try_build()
class ZMQRecv(object):
def __init__(self, end_point, types, hwm=None, name=None):
class ZMQSocket(object):
def __init__(self, end_point, types, hwm=None, bind=True, name=None):
self._types = types
assert isinstance(bind, bool), bind
if name is None:
self._name = (tf.get_default_graph()
......@@ -47,7 +47,7 @@ class ZMQRecv(object):
self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection(
end_point, hwm, shared_name=self._name)
end_point, hwm, bind=bind, shared_name=self._name)
@property
def name(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment