Commit a0d60a64 authored by Yuxin Wu's avatar Yuxin Wu

Fix build and tests of zmq op (#362)

parent 05494dd6
# $File: Makefile # $File: Makefile
# $Date: Tue Oct 31 11:44:27 2017 +0800 # $Date: Tue Dec 12 18:04:22 2017 -0800
OBJ_DIR = obj OBJ_DIR = obj
PYTHON = python
UNAME_S := $(shell uname -s) UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux) ifeq ($(UNAME_S),Linux)
...@@ -21,15 +22,19 @@ INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS)) ...@@ -21,15 +22,19 @@ INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs) LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS += $(INCLUDE_DIR) CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
# TODO https://github.com/tensorflow/tensorflow/issues/1569 ifneq ($(MAKECMDGOALS), clean)
# You may need to disable this flag if you compile tensorflow yourself with gcc>=5 TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0 TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(OPTFLAGS) LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC LDFLAGS += -shared -fPIC
LDFLAGS += $(TF_LDFLAGS)
ifeq ($(UNAME_S),Darwin) ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif endif
...@@ -41,7 +46,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o)) ...@@ -41,7 +46,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d) DEPFILES = $(OBJS:.o=.d)
# TODO what about mac? # TODO what about mac?
SO = $(ccSOURCES:.cc=.so) SO = zmq_recv_op.so
.PHONY: all clean .PHONY: all clean
......
...@@ -9,10 +9,11 @@ import os ...@@ -9,10 +9,11 @@ import os
def compile(): def compile():
# TODO check modtime? cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
include_dir = tf.sysconfig.get_include() ldflags = ' '.join(tf.sysconfig.get_link_flags())
file_dir = os.path.dirname(os.path.abspath(__file__)) file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'INCLUDE_DIR="-isystem {}" make -C "{}"'.format(include_dir, file_dir) compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'.format(
cxxflags, ldflags, file_dir)
ret = os.system(compile_cmd) ret = os.system(compile_cmd)
return ret return ret
...@@ -20,6 +21,7 @@ def compile(): ...@@ -20,6 +21,7 @@ def compile():
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40 # https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix(): def get_ext_suffix():
"""Determine library extension for various versions of Python.""" """Determine library extension for various versions of Python."""
return '.so' # TODO
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
if ext_suffix: if ext_suffix:
return ext_suffix return ext_suffix
......
...@@ -11,46 +11,46 @@ import numpy as np ...@@ -11,46 +11,46 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa from tensorpack.user_ops.zmq_recv import ( # noqa
zmq_recv, dump_tensor_protos, to_tensor_proto) zmq_recv, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
try:
num = int(sys.argv[1])
except ValueError:
num = 2
ENDPOINT = 'ipc://test-pipe' ENDPOINT = 'ipc://test-pipe'
DATA = [] if __name__ == '__main__':
for k in range(num): try:
arr1 = np.random.rand(k + 10, k + 10).astype('float32') num = int(sys.argv[1])
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8') except (ValueError, IndexError):
DATA.append([arr1, arr2]) num = 10
DATA = []
def send(): for k in range(num):
ctx = zmq.Context() arr1 = np.random.rand(k + 10, k + 10).astype('float32')
sok = ctx.socket(zmq.PUSH) arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
sok.connect(ENDPOINT) DATA.append([arr1, arr2])
for arr1, arr2 in DATA: def send():
t1 = to_tensor_proto(arr1) ctx = zmq.Context()
t2 = to_tensor_proto(arr2) sok = ctx.socket(zmq.PUSH)
t = dump_tensor_protos([t1, t2]) sok.connect(ENDPOINT)
sok.send(t)
for dp in DATA:
sok.send(dumps_zmq_op(dp))
def recv():
sess = tf.InteractiveSession() def recv():
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8]) sess = tf.Session()
print(recv) recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
print(recv)
for truth in DATA:
arr = sess.run(recv) for truth in DATA:
assert (arr[0] == truth[0]).all() arr = sess.run(recv)
assert (arr[1] == truth[1]).all() assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p = mp.Process(target=send) p = mp.Process(target=send)
p.start() ensure_proc_terminate(p)
recv() start_proc_mask_signal(p)
p.join() recv()
p.join()
...@@ -32,9 +32,8 @@ struct RecvTensorList { ...@@ -32,9 +32,8 @@ struct RecvTensorList {
class ZMQConnection { class ZMQConnection {
public: public:
ZMQConnection(std::string endpoint, int zmq_socket_type): ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) { ctx_(1), sock_(ctx_, zmq_socket_type) {
int hwm = 100; // TODO make it an option
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm); sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str()); sock_.bind(endpoint.c_str());
} }
......
...@@ -13,7 +13,7 @@ from tensorflow.core.framework import types_pb2 as DataType ...@@ -13,7 +13,7 @@ from tensorflow.core.framework import types_pb2 as DataType
from .common import compile, get_ext_suffix from .common import compile, get_ext_suffix
__all__ = ['zmq_recv', 'dumps_for_tfop', __all__ = ['zmq_recv', 'dumps_zmq_op',
'dump_tensor_protos', 'to_tensor_proto'] 'dump_tensor_protos', 'to_tensor_proto']
...@@ -26,7 +26,7 @@ def build(): ...@@ -26,7 +26,7 @@ def build():
else: else:
file_dir = os.path.dirname(os.path.abspath(__file__)) file_dir = os.path.dirname(os.path.abspath(__file__))
recv_mod = tf.load_op_library( recv_mod = tf.load_op_library(
os.path.join(file_dir, 'zmq_recv_op.' + get_ext_suffix())) os.path.join(file_dir, 'zmq_recv_op' + get_ext_suffix()))
zmq_recv = recv_mod.zmq_recv zmq_recv = recv_mod.zmq_recv
...@@ -51,6 +51,7 @@ def to_tensor_proto(arr): ...@@ -51,6 +51,7 @@ def to_tensor_proto(arr):
Args: Args:
arr: numpy.ndarray. only supports common numerical types arr: numpy.ndarray. only supports common numerical types
""" """
assert isinstance(arr, np.ndarray), type(arr)
dtype = _DTYPE_DICT[arr.dtype] dtype = _DTYPE_DICT[arr.dtype]
ret = TensorProto() ret = TensorProto()
...@@ -100,9 +101,15 @@ def dump_tensor_protos(protos): ...@@ -100,9 +101,15 @@ def dump_tensor_protos(protos):
return s return s
def dumps_for_tfop(dp): def dumps_zmq_op(dp):
""" """
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept. Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept.
Args:
dp: list of nparray
Returns:
a binary string
""" """
protos = [to_tensor_proto(arr) for arr in dp] protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos) return dump_tensor_protos(protos)
...@@ -16,11 +16,12 @@ REGISTER_OP("ZMQRecv") ...@@ -16,11 +16,12 @@ REGISTER_OP("ZMQRecv")
.Output("output: types") .Output("output: types")
.Attr("end_point: string") .Attr("end_point: string")
.Attr("types: list(type) >= 1") .Attr("types: list(type) >= 1")
.Attr("hwm: int >= 1 = 100")
.SetShapeFn(shape_inference::UnknownShape) .SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful() .SetIsStateful()
.Doc(R"doc( .Doc(R"doc(
Receive a serialized list of Tensors from a ZMQ socket. Receive a list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format. The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc"); )doc");
...@@ -32,7 +33,10 @@ class ZMQRecvOp: public OpKernel { ...@@ -32,7 +33,10 @@ class ZMQRecvOp: public OpKernel {
string endpoint; string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint)); OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL));
int hwm;
OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment