Commit fc96b3b9 authored by Yuxin Wu's avatar Yuxin Wu

speed up zmq op

parent 88d7baeb
*.gz # tensorpack-specific stuff
*.npy
train_log train_log
tensorpack/user_ops/obj
*.npy
*.bin
*.tfmodel
*.meta
*.log*
model-*
.gitignore
*.caffemodel
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
# my personal stuff
snippet
examples/private
TODO.md
*.gz
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
...@@ -61,21 +83,3 @@ docs/_build/ ...@@ -61,21 +83,3 @@ docs/_build/
target/ target/
*.dat *.dat
*.bin
*.tfmodel
*.meta
*.log*
model-*
.gitignore
*.caffemodel
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
# my personal stuff
snippet
examples/private
TODO.md
...@@ -47,7 +47,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'): ...@@ -47,7 +47,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
q.append(time.time() - start) q.append(time.time() - start)
pbar.update(1) pbar.update(1)
if pbar.n % print_interval == 0: if pbar.n % print_interval == 0:
pbar.write("Avg send time: {}".format(sum(q) / len(q))) pbar.write("Avg send time @{}: {}".format(pbar.n, sum(q) / len(q)))
finally: finally:
socket.setsockopt(zmq.LINGER, 0) socket.setsockopt(zmq.LINGER, 0)
socket.close() socket.close()
......
...@@ -300,7 +300,7 @@ class ZMQInput(FeedfreeInput): ...@@ -300,7 +300,7 @@ class ZMQInput(FeedfreeInput):
def get_input_tensors(self): def get_input_tensors(self):
from tensorpack.user_ops import zmq_recv from tensorpack.user_ops import zmq_recv
ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs]) ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
if isinstance(self._recv, tf.Tensor): if isinstance(ret, tf.Tensor):
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs): for qv, v in zip(ret, self.input_placehdrs):
......
...@@ -89,7 +89,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -89,7 +89,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self._input_method = QueueInput(config.dataflow, input_queue) self._input_method = QueueInput(config.dataflow, input_queue)
else: else:
self._input_method = config.data self._input_method = config.data
assert isinstance(self._input_method, QueueInput) # assert isinstance(self._input_method, QueueInput)
if predict_tower is not None: if predict_tower is not None:
log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!") log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
......
...@@ -22,8 +22,8 @@ ENDPOINT = 'ipc://test-pipe' ...@@ -22,8 +22,8 @@ ENDPOINT = 'ipc://test-pipe'
DATA = [] DATA = []
for k in range(num): for k in range(num):
arr1 = np.random.rand(k).astype('float32') arr1 = np.random.rand(k + 10, k + 10).astype('float32')
arr2 = (np.random.rand(k * 2) * 10).astype('uint8') arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
DATA.append([arr1, arr2]) DATA.append([arr1, arr2])
......
...@@ -6,51 +6,62 @@ ...@@ -6,51 +6,62 @@
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <tensorflow/core/framework/tensor.pb.h> #include <tensorflow/core/framework/tensor.pb.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include "zmq.hpp" #include "zmq.hpp"
namespace { namespace {
inline int read_int32(const char* p) { inline int read_int32(char** p) {
auto pi = reinterpret_cast<const int*>(p); auto pi = reinterpret_cast<const int*>(*p);
*p += 4;
return *pi; return *pi;
} }
} }
struct RecvTensorList {
zmq::message_t message;
struct TensorConstructor {
// TODO make it allocated on stack
// only contains shape and type
tensorflow::TensorProto meta;
char* buf;
int size;
};
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
};
class ZMQConnection { class ZMQConnection {
public: public:
ZMQConnection(std::string endpoint, int zmq_socket_type): ZMQConnection(std::string endpoint, int zmq_socket_type):
ctx_(1), sock_(ctx_, zmq_socket_type) { ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.bind(endpoint.c_str()); int hwm = 100; // TODO make it an option
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str());
} }
tensorflow::TensorProto recv_tensor() { void recv_tensor_list(RecvTensorList* tlist) {
zmq::message_t message;
bool succ = sock_.recv(&message);
CHECK(succ); // no EAGAIN, because we are blocking
tensorflow::TensorProto ret{};
CHECK(ret.ParseFromArray(message.data(), message.size()));
return ret;
}
std::vector<tensorflow::TensorProto> recv_tensor_list() {
zmq::message_t message;
// TODO critical section // TODO critical section
bool succ = sock_.recv(&message); bool succ = sock_.recv(&tlist->message);
CHECK(succ); // no EAGAIN, because we are blocking CHECK(succ); // no EAGAIN, because we are blocking
char* pos = reinterpret_cast<char*>(message.data()); char* pos = reinterpret_cast<char*>(tlist->message.data());
int num = read_int32(pos); int num = read_int32(&pos);
auto& tensors = tlist->tensors;
tensors.resize(num);
CHECK_LE(num, 15); // probably a format error
std::vector<tensorflow::TensorProto> ret(num);
pos += sizeof(int);
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
int size = read_int32(pos); int sz = read_int32(&pos);
pos += sizeof(int); CHECK(tensors[i].meta.ParseFromArray(pos, sz));
//std::cout << "Message size:" << size << std::endl; pos += sz;
CHECK(ret[i].ParseFromArray(pos, size));
pos += size; sz = read_int32(&pos);
tensors[i].buf = pos;
tensors[i].size = sz;
pos += sz;
} }
return ret;
} }
private: private:
......
...@@ -19,7 +19,8 @@ REGISTER_OP("ZMQRecv") ...@@ -19,7 +19,8 @@ REGISTER_OP("ZMQRecv")
.SetShapeFn(shape_inference::UnknownShape) .SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful() .SetIsStateful()
.Doc(R"doc( .Doc(R"doc(
Receive and return a serialized list of TensorProto from a ZMQ socket. Receive a serialized list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format.
)doc"); )doc");
...@@ -27,7 +28,7 @@ class ZMQRecvOp: public OpKernel { ...@@ -27,7 +28,7 @@ class ZMQRecvOp: public OpKernel {
public: public:
explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) { explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_)); OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK(conn_.get() == nullptr); CHECK_EQ(conn_.get(), nullptr);
string endpoint; string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint)); OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
...@@ -35,27 +36,34 @@ class ZMQRecvOp: public OpKernel { ...@@ -35,27 +36,34 @@ class ZMQRecvOp: public OpKernel {
} }
void Compute(OpKernelContext* ctx) override { void Compute(OpKernelContext* ctx) override {
//GuardedTimer tm("Compute");
int start, stop; int start, stop;
TF_CHECK_OK(this->OutputRange("output", &start, &stop)); TF_CHECK_OK(this->OutputRange("output", &start, &stop));
//cout << "COMPUTE" << endl; RecvTensorList tlist;
auto protos = conn_->recv_tensor_list(); conn_->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors;
OpOutputList outputs; OpOutputList outputs;
OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs)); OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
CHECK(protos.size() == num_components()); CHECK(tensors.size() == num_components());
for (int i = start; i < stop; ++i) { for (int i = start; i < stop; ++i) {
Tensor output; Tensor* output = nullptr;
int j = i - start; int j = i - start;
OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto( auto recv_dtype = tensors[j].meta.dtype();
protos[j], ctx->output_alloc_attr(i), &output));
OP_REQUIRES( OP_REQUIRES(
ctx, component_types_[j] == output.dtype(), ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (", errors::InvalidArgument("Type mismatch between parsed tensor (",
DataTypeString(output.dtype()), ") and dtype (", DataTypeString(recv_dtype), ") and dtype (",
DataTypeString(component_types_[j]), ")")); DataTypeString(component_types_[j]), ")"));
outputs.set(j, output);
TensorShape shape{tensors[j].meta.tensor_shape()};
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
outputs.set(j, *output);
} }
} }
private: private:
......
...@@ -78,14 +78,22 @@ def dump_tensor_protos(protos): ...@@ -78,14 +78,22 @@ def dump_tensor_protos(protos):
protos (list): list of :class:`TensorProto` instance protos (list): list of :class:`TensorProto` instance
Notes: Notes:
The format is: <#protos(int32)>|<size 1>|<serialized proto 1>|<size 2><serialized proto 2>| ... The format is:
[#tensors(int32)]
(tensor1)[size of meta proto][serialized meta proto][size of buffer][buffer]
(tensor2)...
""" """
s = struct.pack('=i', len(protos)) s = struct.pack('=i', len(protos))
for p in protos: for p in protos:
tensor_content = p.tensor_content
p.tensor_content = b'xxx' # clear content
buf = p.SerializeToString() buf = p.SerializeToString()
s += struct.pack('=i', len(buf)) # won't send stuff over 2G s += struct.pack('=i', len(buf))
s += buf s += buf
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content
return s return s
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment