Commit 398cb933 authored by Yuxin Wu

[zmq op] get rid of protobuf dependency (#221)

parent d899c925
# $File: Makefile # $File: Makefile
# $Date: Wed Jun 17 20:52:38 2015 +0800 # $Date: Thu Aug 03 16:14:29 2017 -0700
OBJ_DIR = obj OBJ_DIR = obj
...@@ -16,7 +16,7 @@ OPTFLAGS ?= -O3 -march=native ...@@ -16,7 +16,7 @@ OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan #OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# extra packages from pkg-config # extra packages from pkg-config
LIBS = protobuf libzmq LIBS = libzmq
INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS)) INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs) LDFLAGS += $(shell pkg-config $(LIBS) --libs)
...@@ -24,7 +24,7 @@ CXXFLAGS += $(INCLUDE_DIR) ...@@ -24,7 +24,7 @@ CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra CXXFLAGS += -Wall -Wextra
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
# https://github.com/tensorflow/tensorflow/issues/1569 # TODO https://github.com/tensorflow/tensorflow/issues/1569
# You may need to disable this flag if you compile tensorflow yourself with gcc>=5 # You may need to disable this flag if you compile tensorflow yourself with gcc>=5
CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0 CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0
......
...@@ -16,6 +16,7 @@ print("Compiling user ops ...") ...@@ -16,6 +16,7 @@ print("Compiling user ops ...")
ret = os.system(compile_cmd) ret = os.system(compile_cmd)
if ret != 0: if ret != 0:
print("Failed to compile user ops!") print("Failed to compile user ops!")
zmq_recv = None
else: else:
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so')) recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
# TODO trigger recompile when load fails # TODO trigger recompile when load fails
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <tensorflow/core/framework/tensor.pb.h> #include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h> #include <tensorflow/core/lib/gtl/inlined_vector.h>
#include "zmq.hpp" #include "zmq.hpp"
...@@ -21,11 +21,10 @@ struct RecvTensorList { ...@@ -21,11 +21,10 @@ struct RecvTensorList {
zmq::message_t message; zmq::message_t message;
struct TensorConstructor { struct TensorConstructor {
// TODO make it allocated on stack tensorflow::DataType dtype;
// only contains shape and type tensorflow::TensorShape shape;
tensorflow::TensorProto meta; int size; // TODO bufsize
char* buf; char* buf;
int size;
}; };
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors; tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
...@@ -53,11 +52,15 @@ class ZMQConnection { ...@@ -53,11 +52,15 @@ class ZMQConnection {
CHECK_LE(num, 15); // probably a format error CHECK_LE(num, 15); // probably a format error
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int dt = read_int32(&pos);
tensors[i].dtype = tensorflow::DataType(dt);
int ndim = read_int32(&pos);
CHECK_LE(ndim, 8); // probably an error.
for (int k = 0; k < ndim; ++k) {
int shp = read_int32(&pos);
tensors[i].shape.AddDim(shp);
}
int sz = read_int32(&pos); int sz = read_int32(&pos);
CHECK(tensors[i].meta.ParseFromArray(pos, sz));
pos += sz;
sz = read_int32(&pos);
tensors[i].buf = pos; tensors[i].buf = pos;
tensors[i].size = sz; tensors[i].size = sz;
pos += sz; pos += sz;
......
...@@ -51,7 +51,7 @@ class ZMQRecvOp: public OpKernel { ...@@ -51,7 +51,7 @@ class ZMQRecvOp: public OpKernel {
for (int i = start; i < stop; ++i) { for (int i = start; i < stop; ++i) {
Tensor* output = nullptr; Tensor* output = nullptr;
int j = i - start; int j = i - start;
auto recv_dtype = tensors[j].meta.dtype(); auto recv_dtype = tensors[j].dtype;
OP_REQUIRES( OP_REQUIRES(
ctx, component_types_[j] == recv_dtype, ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (", errors::InvalidArgument("Type mismatch between parsed tensor (",
...@@ -59,7 +59,7 @@ class ZMQRecvOp: public OpKernel { ...@@ -59,7 +59,7 @@ class ZMQRecvOp: public OpKernel {
DataTypeString(component_types_[j]), ")")); DataTypeString(component_types_[j]), ")"));
TensorShape shape{tensors[j].meta.tensor_shape()}; TensorShape& shape = tensors[j].shape;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output)); OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}); auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size); memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
......
...@@ -81,17 +81,24 @@ def dump_tensor_protos(protos): ...@@ -81,17 +81,24 @@ def dump_tensor_protos(protos):
The format is: The format is:
[#tensors(int32)] [#tensors(int32)]
(tensor1)[size of meta proto][serialized meta proto][size of buffer][buffer] [tensor1][tensor2]...
(tensor2)...
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int32)][buffer]
""" """
# TODO use int64
s = struct.pack('=i', len(protos)) s = struct.pack('=i', len(protos))
for p in protos: for p in protos:
tensor_content = p.tensor_content tensor_content = p.tensor_content
p.tensor_content = b'xxx' # clear content
buf = p.SerializeToString() s += struct.pack('=i', int(p.dtype))
s += struct.pack('=i', len(buf)) dims = p.tensor_shape.dim
s += buf s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content s += tensor_content
return s return s
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment