Commit 398cb933 authored by Yuxin Wu's avatar Yuxin Wu

[zmq op] get rid of protobuf dependency (#221)

parent d899c925
# $File: Makefile
# $Date: Wed Jun 17 20:52:38 2015 +0800
# $Date: Thu Aug 03 16:14:29 2017 -0700
OBJ_DIR = obj
......@@ -16,7 +16,7 @@ OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan
# extra packages from pkg-config
LIBS = protobuf libzmq
LIBS = libzmq
INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs)
......@@ -24,7 +24,7 @@ CXXFLAGS += $(INCLUDE_DIR)
CXXFLAGS += -Wall -Wextra
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
# https://github.com/tensorflow/tensorflow/issues/1569
# TODO https://github.com/tensorflow/tensorflow/issues/1569
# You may need to disable this flag if you compile tensorflow yourself with gcc>=5
CXXFLAGS += -D_GLIBCXX_USE_CXX11_ABI=0
......
......@@ -16,6 +16,7 @@ print("Compiling user ops ...")
ret = os.system(compile_cmd)
if ret != 0:
print("Failed to compile user ops!")
zmq_recv = None
else:
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
# TODO trigger recompile when load fails
......
......@@ -5,7 +5,7 @@
#include <string>
#include <iostream>
#include <tensorflow/core/framework/tensor.pb.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include "zmq.hpp"
......@@ -21,11 +21,10 @@ struct RecvTensorList {
zmq::message_t message;
// Per-tensor reconstruction info parsed out of a received ZMQ message.
// NOTE(review): this span is a raw commit-diff rendering with the +/- markers
// stripped — `tensorflow::TensorProto meta;` and the first `int size;` appear
// to be the REMOVED (pre-change) members, while dtype/shape/buf/size are the
// ADDED ones replacing the protobuf dependency; verify against the actual
// commit before treating this as compilable code (as-is, `size` is declared
// twice).
struct TensorConstructor {
// TODO make it allocated on stack
// only contains shape and type
tensorflow::TensorProto meta;
tensorflow::DataType dtype;
tensorflow::TensorShape shape;
int size; // TODO bufsize
// `buf` is non-owning: the parse loop below sets it to a cursor inside the
// received zmq::message_t payload, so it is only valid while `message` lives.
char* buf;
int size;
};
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
......@@ -53,11 +52,15 @@ class ZMQConnection {
CHECK_LE(num, 15); // probably a format error
for (int i = 0; i < num; ++i) {
int dt = read_int32(&pos);
tensors[i].dtype = tensorflow::DataType(dt);
int ndim = read_int32(&pos);
CHECK_LE(ndim, 8); // probably an error.
for (int k = 0; k < ndim; ++k) {
int shp = read_int32(&pos);
tensors[i].shape.AddDim(shp);
}
int sz = read_int32(&pos);
CHECK(tensors[i].meta.ParseFromArray(pos, sz));
pos += sz;
sz = read_int32(&pos);
tensors[i].buf = pos;
tensors[i].size = sz;
pos += sz;
......
......@@ -51,7 +51,7 @@ class ZMQRecvOp: public OpKernel {
for (int i = start; i < stop; ++i) {
Tensor* output = nullptr;
int j = i - start;
auto recv_dtype = tensors[j].meta.dtype();
auto recv_dtype = tensors[j].dtype;
OP_REQUIRES(
ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (",
......@@ -59,7 +59,7 @@ class ZMQRecvOp: public OpKernel {
DataTypeString(component_types_[j]), ")"));
TensorShape shape{tensors[j].meta.tensor_shape()};
TensorShape& shape = tensors[j].shape;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
......
......@@ -81,17 +81,24 @@ def dump_tensor_protos(protos):
The format is:
[#tensors(int32)]
(tensor1)[size of meta proto][serialized meta proto][size of buffer][buffer]
(tensor2)...
[tensor1][tensor2]...
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int32)][buffer]
"""
# TODO use int64
s = struct.pack('=i', len(protos))
for p in protos:
tensor_content = p.tensor_content
p.tensor_content = b'xxx' # clear content
buf = p.SerializeToString()
s += struct.pack('=i', len(buf))
s += buf
s += struct.pack('=i', int(p.dtype))
dims = p.tensor_shape.dim
s += struct.pack('=i', len(dims))
for k in dims:
s += struct.pack('=i', k.size)
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content
return s
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment