Commit 2da6f9ed authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] correct so name; use int64; support scalar; (#362)

parent 65c8b239
......@@ -384,13 +384,15 @@ class ZMQInput(TensorInput):
"""
Not well implemented yet. Don't use.
"""
def __init__(self, endpoint):
def __init__(self, endpoint, hwm):
self._endpoint = endpoint
from tensorpack.user_ops import zmq_recv
def fn():
ret = zmq_recv(self._endpoint, [x.dtype for x in self.inputs_desc])
ret = zmq_recv(
self._endpoint, [x.dtype for x in self.inputs_desc],
hwm=hwm)
if isinstance(ret, tf.Tensor):
ret = [ret]
assert len(ret) == len(self.inputs_desc)
......
# $File: Makefile
# $Date: Tue Dec 12 18:04:22 2017 -0800
# $Date: Tue Dec 12 22:27:38 2017 -0800
OBJ_DIR = obj
PYTHON = python
......@@ -45,8 +45,8 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)
# TODO what about mac?
SO = zmq_recv_op.so
EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
SO = zmq_recv_op$(EXT_SUFFIX)
.PHONY: all clean
......@@ -56,7 +56,7 @@ ifneq ($(MAKECMDGOALS), clean)
sinclude $(DEPFILES)
endif
%.so: $(OBJ_DIR)/%.o
%$(EXT_SUFFIX): $(OBJ_DIR)/%.o
@echo "Linking $@ ..."
@$(CXX) $^ -o $@ $(LDFLAGS)
@echo "done."
......
......@@ -2,26 +2,16 @@
# -*- coding: utf-8 -*-
# File: common.py
from __future__ import print_function
import sysconfig
import tensorflow as tf
import os
def compile():
cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
ldflags = ' '.join(tf.sysconfig.get_link_flags())
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'.format(
cxxflags, ldflags, file_dir)
ret = os.system(compile_cmd)
return ret
from ..utils import logger
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix():
"""Determine library extension for various versions of Python."""
return '.so' # TODO
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
if ext_suffix:
return ext_suffix
......@@ -33,5 +23,17 @@ def get_ext_suffix():
return '.so'
def compile():
cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
ldflags = ' '.join(tf.sysconfig.get_link_flags())
ext_suffix = get_ext_suffix()
file_dir = os.path.dirname(os.path.abspath(__file__))
compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" EXT_SUFFIX="{}" make -C "{}"'.format(
cxxflags, ldflags, ext_suffix, file_dir)
logger.info("Compile user_ops by command " + compile_cmd + ' ...')
ret = os.system(compile_cmd)
return ret
if __name__ == '__main__':
compile()
......@@ -38,11 +38,18 @@ def random_array(num):
ret = []
for k in range(num):
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
# arr1 = 3.0
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
ret.append([arr1, arr2])
return ret
def constant_array(num):
arr = np.ones((30, 30)).astype('float32')
arr2 = np.ones((3, 3)).astype('uint8')
return [[arr, arr2]] * num
def hash_dp(dp):
return sum([k.sum() for k in dp])
......@@ -50,7 +57,7 @@ def hash_dp(dp):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='basic',
choices=['basic', 'tworecv'])
choices=['basic', 'tworecv', 'send'])
parser.add_argument('-n', '--num', type=int, default=10)
args = parser.parse_args()
......@@ -68,10 +75,10 @@ if __name__ == '__main__':
arr = sess.run(recv)
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p.join()
if args.task == 'tworecv':
elif args.task == 'send':
DATA = random_array(args.num)
send(DATA)
elif args.task == 'tworecv':
DATA = random_array(args.num)
hashes = [hash_dp(dp) for dp in DATA]
print(hashes)
......
......@@ -16,6 +16,12 @@ inline int read_int32(char** p) {
*p += 4;
return *pi;
}
inline tensorflow::int64 read_int64(char** p) {
auto pi = reinterpret_cast<const long long*>(*p);
*p += 8;
return *pi;
}
}
namespace tensorpack {
......@@ -26,7 +32,7 @@ struct RecvTensorList {
struct TensorConstructor {
tensorflow::DataType dtype;
tensorflow::TensorShape shape;
int size; // TODO bufsize
tensorflow::int64 buf_size;
char* buf;
};
......@@ -46,8 +52,9 @@ class ZMQConnection {
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// zmq socket is not thread safe
tensorflow::mutex_lock lk(mu_);
bool succ = sock_.recv(&tlist->message); // TODO this may throw
// possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
bool succ = sock_.recv(&tlist->message); // block until some data appears
// TODO this may throw, handle exception?
// Possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// succ=false only if EAGAIN
CHECK(succ); // no EAGAIN, because we are blocking
}
......@@ -68,9 +75,9 @@ class ZMQConnection {
int shp = read_int32(&pos);
tensors[i].shape.AddDim(shp);
}
int sz = read_int32(&pos);
tensorflow::int64 sz = read_int64(&pos);
tensors[i].buf = pos;
tensors[i].size = sz;
tensors[i].buf_size = sz;
pos += sz;
}
}
......
......@@ -17,16 +17,17 @@ __all__ = ['zmq_recv', 'dumps_zmq_op',
'dump_tensor_protos', 'to_tensor_proto']
# TODO '.so' for linux only
def build():
global zmq_recv
file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_recv_op' + get_ext_suffix()
so_file = os.path.join(file_dir, basename)
if not os.path.isfile(so_file):
ret = compile()
if ret != 0:
zmq_recv = None
else:
file_dir = os.path.dirname(os.path.abspath(__file__))
recv_mod = tf.load_op_library(
os.path.join(file_dir, 'zmq_recv_op' + get_ext_suffix()))
raise RuntimeError("tensorpack user_ops compilation failed!")
recv_mod = tf.load_op_library(so_file)
zmq_recv = recv_mod.zmq_recv
......@@ -43,7 +44,6 @@ _DTYPE_DICT = {
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
# TODO support string tensor and scalar
def to_tensor_proto(arr):
"""
Convert a numpy array to TensorProto
......@@ -51,8 +51,15 @@ def to_tensor_proto(arr):
Args:
arr: numpy.ndarray. only supports common numerical types
"""
if isinstance(arr, float):
arr = np.asarray(arr).astype('float32')
elif isinstance(arr, int):
arr = np.asarray(arr).astype('int32')
assert isinstance(arr, np.ndarray), type(arr)
try:
dtype = _DTYPE_DICT[arr.dtype]
except KeyError:
raise KeyError("Dtype {} is unsupported by current ZMQ Op!".format(arr.dtype))
ret = TensorProto()
shape = ret.tensor_shape
......@@ -83,9 +90,8 @@ def dump_tensor_protos(protos):
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int32)][buffer]
[len(buffer)(int64)][buffer]
"""
# TODO use int64
s = struct.pack('=i', len(protos))
for p in protos:
......@@ -96,7 +102,7 @@ def dump_tensor_protos(protos):
s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += struct.pack('=q', len(tensor_content))
s += tensor_content
return s
......@@ -111,5 +117,6 @@ def dumps_zmq_op(dp):
Returns:
a binary string
"""
assert isinstance(dp, (list, tuple))
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
......@@ -27,7 +27,6 @@ The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'
namespace tensorpack {
class ZMQRecvOp: public AsyncOpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
......@@ -39,6 +38,7 @@ class ZMQRecvOp: public AsyncOpKernel {
int hwm;
OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
// will get called only at the first sess.run call
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
}
......@@ -61,15 +61,16 @@ class ZMQRecvOp: public AsyncOpKernel {
auto recv_dtype = tensors[j].dtype;
OP_REQUIRES_ASYNC(
ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (",
DataTypeString(recv_dtype), ") and dtype (",
DataTypeString(component_types_[j]), ")"), done);
errors::InvalidArgument("Type mismatch at index ", std::to_string(j),
" between received tensor (", DataTypeString(recv_dtype),
") and dtype (", DataTypeString(component_types_[j]), ")"),
done);
TensorShape& shape = tensors[j].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
memcpy(ptr, tensors[j].buf, tensors[j].buf_size);
outputs.set(j, *output);
}
done();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment