Commit a0d60a64 authored by Yuxin Wu's avatar Yuxin Wu

Fix build and tests of zmq op (#362)

parent 05494dd6
# $File: Makefile
# $Date: Tue Oct 31 11:44:27 2017 +0800
# $Date: Tue Dec 12 18:04:22 2017 -0800
OBJ_DIR = obj
PYTHON = python
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
......@@ -21,15 +22,19 @@ INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs)
CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra
CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
# TODO https://github.com/tensorflow/tensorflow/issues/1569
# You may need to disable this flag if you compile tensorflow yourself with gcc>=5
CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC
LDFLAGS += $(TF_LDFLAGS)
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif
......@@ -41,7 +46,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)
# TODO what about mac?
SO = $(ccSOURCES:.cc=.so)
SO = zmq_recv_op.so
.PHONY: all clean
......
......@@ -9,10 +9,11 @@ import os
def compile():
# TODO check modtime?
include_dir = tf.sysconfig.get_include()
cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
ldflags = ' '.join(tf.sysconfig.get_link_flags())
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'INCLUDE_DIR="-isystem {}" make -C "{}"'.format(include_dir, file_dir)
compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'.format(
cxxflags, ldflags, file_dir)
ret = os.system(compile_cmd)
return ret
......@@ -20,6 +21,7 @@ def compile():
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix():
"""Determine library extension for various versions of Python."""
return '.so' # TODO
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
if ext_suffix:
return ext_suffix
......
......@@ -11,46 +11,46 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa
zmq_recv, dump_tensor_protos, to_tensor_proto)
zmq_recv, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
try:
num = int(sys.argv[1])
except ValueError:
num = 2
ENDPOINT = 'ipc://test-pipe'
DATA = []
for k in range(num):
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
DATA.append([arr1, arr2])
def send():
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.connect(ENDPOINT)
for arr1, arr2 in DATA:
t1 = to_tensor_proto(arr1)
t2 = to_tensor_proto(arr2)
t = dump_tensor_protos([t1, t2])
sok.send(t)
def recv():
sess = tf.InteractiveSession()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
print(recv)
for truth in DATA:
arr = sess.run(recv)
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p = mp.Process(target=send)
p.start()
recv()
p.join()
if __name__ == '__main__':
try:
num = int(sys.argv[1])
except (ValueError, IndexError):
num = 10
DATA = []
for k in range(num):
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
DATA.append([arr1, arr2])
def send():
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.connect(ENDPOINT)
for dp in DATA:
sok.send(dumps_zmq_op(dp))
def recv():
sess = tf.Session()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
print(recv)
for truth in DATA:
arr = sess.run(recv)
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p = mp.Process(target=send)
ensure_proc_terminate(p)
start_proc_mask_signal(p)
recv()
p.join()
......@@ -32,9 +32,8 @@ struct RecvTensorList {
class ZMQConnection {
public:
ZMQConnection(std::string endpoint, int zmq_socket_type):
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) {
int hwm = 100; // TODO make it an option
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str());
}
......
......@@ -13,7 +13,7 @@ from tensorflow.core.framework import types_pb2 as DataType
from .common import compile, get_ext_suffix
__all__ = ['zmq_recv', 'dumps_for_tfop',
__all__ = ['zmq_recv', 'dumps_zmq_op',
'dump_tensor_protos', 'to_tensor_proto']
......@@ -26,7 +26,7 @@ def build():
else:
file_dir = os.path.dirname(os.path.abspath(__file__))
recv_mod = tf.load_op_library(
os.path.join(file_dir, 'zmq_recv_op.' + get_ext_suffix()))
os.path.join(file_dir, 'zmq_recv_op' + get_ext_suffix()))
zmq_recv = recv_mod.zmq_recv
......@@ -51,6 +51,7 @@ def to_tensor_proto(arr):
Args:
arr: numpy.ndarray. only supports common numerical types
"""
assert isinstance(arr, np.ndarray), type(arr)
dtype = _DTYPE_DICT[arr.dtype]
ret = TensorProto()
......@@ -100,9 +101,15 @@ def dump_tensor_protos(protos):
return s
def dumps_for_tfop(dp):
def dumps_zmq_op(dp):
"""
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept.
Args:
dp: list of nparray
Returns:
a binary string
"""
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
......@@ -16,11 +16,12 @@ REGISTER_OP("ZMQRecv")
.Output("output: types")
.Attr("end_point: string")
.Attr("types: list(type) >= 1")
.Attr("hwm: int >= 1 = 100")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a serialized list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format.
Receive a list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
......@@ -32,7 +33,10 @@ class ZMQRecvOp: public OpKernel {
string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL));
int hwm;
OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
}
void Compute(OpKernelContext* ctx) override {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment