Commit 2da6f9ed authored by Yuxin Wu

[ZMQ] correct so name; use int64; support scalar; (#362)

parent 65c8b239
@@ -384,13 +384,15 @@ class ZMQInput(TensorInput):
     """
     Not well implemented yet. Don't use.
     """
-    def __init__(self, endpoint):
+    def __init__(self, endpoint, hwm):
         self._endpoint = endpoint
         from tensorpack.user_ops import zmq_recv
         def fn():
-            ret = zmq_recv(self._endpoint, [x.dtype for x in self.inputs_desc])
+            ret = zmq_recv(
+                self._endpoint, [x.dtype for x in self.inputs_desc],
+                hwm=hwm)
             if isinstance(ret, tf.Tensor):
                 ret = [ret]
             assert len(ret) == len(self.inputs_desc)
...
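The new `hwm` argument is forwarded to the underlying `zmq_recv` op (it becomes the `hwm` attr read in zmq_recv_op.cc below). A minimal sketch of the op-level call, assuming a hypothetical `ipc:///tmp/zmq-test` endpoint:

import tensorflow as tf
from tensorpack.user_ops import zmq_recv

# The dtype list must match what the sender serializes; hwm bounds how
# many unconsumed messages ZMQ buffers on the receiving socket.
recv = zmq_recv("ipc:///tmp/zmq-test", [tf.float32, tf.uint8], hwm=100)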
 # $File: Makefile
-# $Date: Tue Dec 12 18:04:22 2017 -0800
+# $Date: Tue Dec 12 22:27:38 2017 -0800
 OBJ_DIR = obj
 PYTHON = python
@@ -45,8 +45,8 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
 OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
 DEPFILES = $(OBJS:.o=.d)
-# TODO what about mac?
-SO = zmq_recv_op.so
+EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
+SO = zmq_recv_op$(EXT_SUFFIX)
 .PHONY: all clean
@@ -56,7 +56,7 @@ ifneq ($(MAKECMDGOALS), clean)
 sinclude $(DEPFILES)
 endif
-%.so: $(OBJ_DIR)/%.o
+%$(EXT_SUFFIX): $(OBJ_DIR)/%.o
 	@echo "Linking $@ ..."
 	@$(CXX) $^ -o $@ $(LDFLAGS)
 	@echo "done."
...
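`EXT_SUFFIX ?=` asks the Python interpreter for its platform-specific extension-module suffix, so the built library now carries the name that `tf.load_op_library` will look for (the "correct so name" part of this commit). To see the value for a given interpreter:

import sysconfig

# On CPython 3 / Linux this prints something like
# ".cpython-36m-x86_64-linux-gnu.so"; on Python 2 it may be None,
# which is why get_ext_suffix() in common.py keeps a '.so' fallback.
print(sysconfig.get_config_var("EXT_SUFFIX"))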
@@ -2,26 +2,16 @@
 # -*- coding: utf-8 -*-
 # File: common.py
-from __future__ import print_function
 import sysconfig
 import tensorflow as tf
 import os
-def compile():
-    cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
-    ldflags = ' '.join(tf.sysconfig.get_link_flags())
-    file_dir = os.path.dirname(os.path.abspath(__file__))
-    compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'.format(
-        cxxflags, ldflags, file_dir)
-    ret = os.system(compile_cmd)
-    return ret
+from ..utils import logger
 # https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
 def get_ext_suffix():
     """Determine library extension for various versions of Python."""
-    return '.so'  # TODO
     ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
     if ext_suffix:
         return ext_suffix
@@ -33,5 +23,17 @@ def get_ext_suffix():
     return '.so'
+def compile():
+    cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
+    ldflags = ' '.join(tf.sysconfig.get_link_flags())
+    ext_suffix = get_ext_suffix()
+    file_dir = os.path.dirname(os.path.abspath(__file__))
+    compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" EXT_SUFFIX="{}" make -C "{}"'.format(
+        cxxflags, ldflags, ext_suffix, file_dir)
+    logger.info("Compile user_ops by command " + compile_cmd + ' ...')
+    ret = os.system(compile_cmd)
+    return ret
 if __name__ == '__main__':
     compile()
...
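`compile()` now also exports `EXT_SUFFIX` into the environment of the `make` call, overriding the Makefile's `?=` default so both sides agree on the file name. A hedged usage sketch (the module path is assumed from the `# File: common.py` header):

from tensorpack.user_ops.common import compile as compile_ops

ret = compile_ops()  # runs `make` with TF's compile/link flags
assert ret == 0, "user_ops compilation failed"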
@@ -38,11 +38,18 @@ def random_array(num):
     ret = []
     for k in range(num):
         arr1 = np.random.rand(k + 10, k + 10).astype('float32')
+        # arr1 = 3.0
         arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
         ret.append([arr1, arr2])
     return ret
+def constant_array(num):
+    arr = np.ones((30, 30)).astype('float32')
+    arr2 = np.ones((3, 3)).astype('uint8')
+    return [[arr, arr2]] * num
 def hash_dp(dp):
     return sum([k.sum() for k in dp])
@@ -50,7 +57,7 @@ def hash_dp(dp):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', default='basic',
-                        choices=['basic', 'tworecv'])
+                        choices=['basic', 'tworecv', 'send'])
     parser.add_argument('-n', '--num', type=int, default=10)
     args = parser.parse_args()
@@ -68,10 +75,10 @@ if __name__ == '__main__':
             arr = sess.run(recv)
             assert (arr[0] == truth[0]).all()
             assert (arr[1] == truth[1]).all()
-        p.join()
-    if args.task == 'tworecv':
+    elif args.task == 'send':
+        DATA = random_array(args.num)
+        send(DATA)
+    elif args.task == 'tworecv':
         DATA = random_array(args.num)
         hashes = [hash_dp(dp) for dp in DATA]
         print(hashes)
...
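The new `send` task turns one process into a pure sender, so the `basic`/`tworecv` receivers can run elsewhere. A sketch of what such a `send` helper plausibly does, assuming pyzmq and a hypothetical endpoint that must match the receiver:

import zmq
from tensorpack.user_ops.zmq_recv import dumps_zmq_op  # import path assumed

def send(dps, endpoint="ipc:///tmp/zmq-test"):
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(endpoint)
    for dp in dps:                   # each dp is a list of numpy arrays
        sock.send(dumps_zmq_op(dp))  # one ZMQ message per datapoint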
@@ -16,6 +16,12 @@ inline int read_int32(char** p) {
   *p += 4;
   return *pi;
 }
+inline tensorflow::int64 read_int64(char** p) {
+  auto pi = reinterpret_cast<const long long*>(*p);
+  *p += 8;
+  return *pi;
+}
 }
 namespace tensorpack {
@@ -26,7 +32,7 @@ struct RecvTensorList {
   struct TensorConstructor {
     tensorflow::DataType dtype;
     tensorflow::TensorShape shape;
-    int size;  // TODO bufsize
+    tensorflow::int64 buf_size;
     char* buf;
   };
@@ -46,8 +52,9 @@ class ZMQConnection {
     // https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
     // zmq socket is not thread safe
     tensorflow::mutex_lock lk(mu_);
-    bool succ = sock_.recv(&tlist->message);  // TODO this may throw
-    // possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
+    bool succ = sock_.recv(&tlist->message);  // block until some data appears
+    // TODO this may throw, handle exception?
+    // Possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
     // succ=false only if EAGAIN
     CHECK(succ);  // no EAGAIN, because we are blocking
   }
@@ -68,9 +75,9 @@ class ZMQConnection {
       int shp = read_int32(&pos);
       tensors[i].shape.AddDim(shp);
     }
-    int sz = read_int32(&pos);
+    tensorflow::int64 sz = read_int64(&pos);
     tensors[i].buf = pos;
-    tensors[i].size = sz;
+    tensors[i].buf_size = sz;
     pos += sz;
   }
 }
...
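`read_int64` is the receiving-side counterpart of the sender's switch from `struct.pack('=i', ...)` to `struct.pack('=q', ...)` for the buffer length. For reference, a Python equivalent of the helper (native byte order, consumes 8 bytes):

import struct

def read_int64(buf, pos):
    (val,) = struct.unpack_from('=q', buf, pos)
    return val, pos + 8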
@@ -17,17 +17,18 @@ __all__ = ['zmq_recv', 'dumps_zmq_op',
            'dump_tensor_protos', 'to_tensor_proto']
-# TODO '.so' for linux only
 def build():
     global zmq_recv
-    ret = compile()
-    if ret != 0:
-        zmq_recv = None
-    else:
-        file_dir = os.path.dirname(os.path.abspath(__file__))
-        recv_mod = tf.load_op_library(
-            os.path.join(file_dir, 'zmq_recv_op' + get_ext_suffix()))
-        zmq_recv = recv_mod.zmq_recv
+    file_dir = os.path.dirname(os.path.abspath(__file__))
+    basename = 'zmq_recv_op' + get_ext_suffix()
+    so_file = os.path.join(file_dir, basename)
+    if not os.path.isfile(so_file):
+        ret = compile()
+        if ret != 0:
+            raise RuntimeError("tensorpack user_ops compilation failed!")
+    recv_mod = tf.load_op_library(so_file)
+    zmq_recv = recv_mod.zmq_recv
 build()
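With this change the library is compiled at most once and then loaded from disk; subsequent imports skip `make` entirely. A hedged sketch of forcing a rebuild under that scheme (the directory is illustrative):

import os
from tensorpack.user_ops.common import get_ext_suffix  # path assumed

so_file = os.path.join("tensorpack/user_ops", "zmq_recv_op" + get_ext_suffix())
if os.path.isfile(so_file):
    os.remove(so_file)  # the next build() will call compile() again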
@@ -43,7 +44,6 @@ _DTYPE_DICT = {
 _DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
-# TODO support string tensor and scalar
 def to_tensor_proto(arr):
     """
     Convert a numpy array to TensorProto
@@ -51,8 +51,15 @@ def to_tensor_proto(arr):
     Args:
         arr: numpy.ndarray. only supports common numerical types
     """
+    if isinstance(arr, float):
+        arr = np.asarray(arr).astype('float32')
+    elif isinstance(arr, int):
+        arr = np.asarray(arr).astype('int32')
     assert isinstance(arr, np.ndarray), type(arr)
-    dtype = _DTYPE_DICT[arr.dtype]
+    try:
+        dtype = _DTYPE_DICT[arr.dtype]
+    except KeyError:
+        raise KeyError("Dtype {} is unsupported by current ZMQ Op!".format(arr.dtype))
     ret = TensorProto()
     shape = ret.tensor_shape
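This is the "support scalar" part of the commit: plain Python floats and ints are promoted to 0-d numpy arrays (float32/int32) before serialization, and unsupported dtypes now fail with a descriptive message. For example (import path assumed):

import numpy as np
from tensorpack.user_ops.zmq_recv import dumps_zmq_op

# 3.0 becomes a float32 scalar tensor and 2 an int32 scalar tensor,
# per the isinstance() promotions in to_tensor_proto above.
msg = dumps_zmq_op([3.0, 2, np.ones((3, 3), dtype='uint8')])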
@@ -83,9 +90,8 @@ def dump_tensor_protos(protos):
     Where each tensor is:
        [dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
-       [len(buffer)(int32)][buffer]
+       [len(buffer)(int64)][buffer]
     """
-    # TODO use int64
     s = struct.pack('=i', len(protos))
     for p in protos:
@@ -96,7 +102,7 @@ def dump_tensor_protos(protos):
         s += struct.pack('=i', len(dims))
         for k in dims:
             s += struct.pack('=i', k.size)
-        s += struct.pack('=i', len(tensor_content))  # won't send stuff over 2G
+        s += struct.pack('=q', len(tensor_content))
         s += tensor_content
     return s
@@ -111,5 +117,6 @@ def dumps_zmq_op(dp):
     Returns:
         a binary string
     """
+    assert isinstance(dp, (list, tuple))
     protos = [to_tensor_proto(arr) for arr in dp]
     return dump_tensor_protos(protos)
...
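Putting the docstring together, a hedged pure-Python decoder for the whole message, written only to document the wire format that zmq_conn.h parses with read_int32/read_int64:

import struct

def parse_tensor_list(s):
    (num,) = struct.unpack_from('=i', s, 0)  # number of tensors
    pos = 4
    tensors = []
    for _ in range(num):
        dtype, ndims = struct.unpack_from('=ii', s, pos)
        pos += 8
        shape = struct.unpack_from('={}i'.format(ndims), s, pos)
        pos += 4 * ndims
        (buf_size,) = struct.unpack_from('=q', s, pos)  # the new int64 length
        pos += 8
        tensors.append((dtype, shape, s[pos:pos + buf_size]))
        pos += buf_size
    return tensors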
@@ -27,7 +27,6 @@ The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'
 namespace tensorpack {
 class ZMQRecvOp: public AsyncOpKernel {
  public:
  explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
@@ -39,6 +38,7 @@ class ZMQRecvOp: public AsyncOpKernel {
     int hwm;
     OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
+    // will get called only at the first sess.run call
     conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
   }
@@ -61,15 +61,16 @@ class ZMQRecvOp: public AsyncOpKernel {
       auto recv_dtype = tensors[j].dtype;
       OP_REQUIRES_ASYNC(
           ctx, component_types_[j] == recv_dtype,
-          errors::InvalidArgument("Type mismatch between parsed tensor (",
-                                  DataTypeString(recv_dtype), ") and dtype (",
-                                  DataTypeString(component_types_[j]), ")"), done);
+          errors::InvalidArgument("Type mismatch at index ", std::to_string(j),
+                                  " between received tensor (", DataTypeString(recv_dtype),
+                                  ") and dtype (", DataTypeString(component_types_[j]), ")"),
+          done);
       TensorShape& shape = tensors[j].shape;
       OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
-      auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
-      memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
+      auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
+      memcpy(ptr, tensors[j].buf, tensors[j].buf_size);
       outputs.set(j, *output);
     }
     done();
...
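For completeness, a hedged end-to-end receive sketch mirroring the test's `basic` task; per the new comment, the ZMQConnection is only created when the kernel is constructed at the first `sess.run`:

import tensorflow as tf
from tensorpack.user_ops import zmq_recv

# "ipc:///tmp/zmq-test" is a hypothetical endpoint shared with a sender.
recv = zmq_recv("ipc:///tmp/zmq-test", [tf.float32, tf.uint8])
with tf.Session() as sess:
    arr = sess.run(recv)  # blocks until one serialized datapoint arrives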