Commit fc96b3b9 authored by Yuxin Wu

speed up zmq op

parent 88d7baeb
*.gz
*.npy
# tensorpack-specific stuff
train_log
tensorpack/user_ops/obj
*.npy
*.bin
*.tfmodel
*.meta
*.log*
model-*
.gitignore
*.caffemodel
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
# my personal stuff
snippet
examples/private
TODO.md
*.gz
# Byte-compiled / optimized / DLL files
__pycache__/
......@@ -61,21 +83,3 @@ docs/_build/
target/
*.dat
*.bin
*.tfmodel
*.meta
*.log*
model-*
.gitignore
*.caffemodel
*.png
*.jpg
checkpoint
*.json
*.prototxt
*.txt
# my personal stuff
snippet
examples/private
TODO.md
......@@ -47,7 +47,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
q.append(time.time() - start)
pbar.update(1)
if pbar.n % print_interval == 0:
pbar.write("Avg send time: {}".format(sum(q) / len(q)))
pbar.write("Avg send time @{}: {}".format(pbar.n, sum(q) / len(q)))
finally:
socket.setsockopt(zmq.LINGER, 0)
socket.close()
......
......@@ -300,7 +300,7 @@ class ZMQInput(FeedfreeInput):
def get_input_tensors(self):
from tensorpack.user_ops import zmq_recv
ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
if isinstance(self._recv, tf.Tensor):
if isinstance(ret, tf.Tensor):
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
......
......@@ -89,7 +89,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self._input_method = QueueInput(config.dataflow, input_queue)
else:
self._input_method = config.data
assert isinstance(self._input_method, QueueInput)
# assert isinstance(self._input_method, QueueInput)
if predict_tower is not None:
log_deprecated("Argument `predict_tower` in trainer", "Use TrainConfig(predict_tower=...) instead!")
......
......@@ -22,8 +22,8 @@ ENDPOINT = 'ipc://test-pipe'
DATA = []
for k in range(num):
arr1 = np.random.rand(k).astype('float32')
arr2 = (np.random.rand(k * 2) * 10).astype('uint8')
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
DATA.append([arr1, arr2])
......
......@@ -6,51 +6,62 @@
#include <string>
#include <iostream>
#include <tensorflow/core/framework/tensor.pb.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include "zmq.hpp"
namespace {
// NOTE(review): this span comes from a rendered diff hunk, so the next two
// lines look like the earlier form of read_int32 (pointer taken by value, no
// cursor advance) shown next to its replacement — confirm against the real file.
inline int read_int32(const char* p) {
auto pi = reinterpret_cast<const int*>(p);
// Reads one int from the bytes at *p and then moves *p forward by 4, so the
// caller's cursor ends up on the field that follows the integer.
// NOTE(review): assumes sizeof(int) == 4 and that *p is suitably aligned for
// an int load — confirm for the message layouts this is used on.
inline int read_int32(char** p) {
// The bytes are reinterpreted in place; no byte-order conversion is applied.
auto pi = reinterpret_cast<const int*>(*p);
*p += 4;
return *pi;
}
}
// Bundles one received ZMQ message with per-tensor descriptors that refer to
// positions inside that message.
struct RecvTensorList {
// The raw message; ZMQConnection::recv_tensor_list receives directly into it.
zmq::message_t message;
// Describes a single tensor found in the message.
struct TensorConstructor {
// TODO make it allocated on stack
// only contains shape and type
tensorflow::TensorProto meta;
// Start of this tensor's raw bytes. recv_tensor_list sets it to a position
// derived from message.data(), so it does not own the memory.
// NOTE(review): presumably only valid while `message` is alive — confirm.
char* buf;
// Length of the buffer at `buf`, as read from the message.
int size;
};
// One entry per tensor; recv_tensor_list resizes it to the count read from
// the start of the message.
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
};
class ZMQConnection {
public:
ZMQConnection(std::string endpoint, int zmq_socket_type):
ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.bind(endpoint.c_str());
int hwm = 100; // TODO make it an option
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str());
}
tensorflow::TensorProto recv_tensor() {
zmq::message_t message;
bool succ = sock_.recv(&message);
CHECK(succ); // no EAGAIN, because we are blocking
tensorflow::TensorProto ret{};
CHECK(ret.ParseFromArray(message.data(), message.size()));
return ret;
}
std::vector<tensorflow::TensorProto> recv_tensor_list() {
zmq::message_t message;
void recv_tensor_list(RecvTensorList* tlist) {
// TODO critical section
bool succ = sock_.recv(&message);
bool succ = sock_.recv(&tlist->message);
CHECK(succ); // no EAGAIN, because we are blocking
char* pos = reinterpret_cast<char*>(message.data());
char* pos = reinterpret_cast<char*>(tlist->message.data());
int num = read_int32(pos);
int num = read_int32(&pos);
auto& tensors = tlist->tensors;
tensors.resize(num);
CHECK_LE(num, 15); // probably a format error
std::vector<tensorflow::TensorProto> ret(num);
pos += sizeof(int);
for (int i = 0; i < num; ++i) {
int size = read_int32(pos);
pos += sizeof(int);
//std::cout << "Message size:" << size << std::endl;
CHECK(ret[i].ParseFromArray(pos, size));
pos += size;
int sz = read_int32(&pos);
CHECK(tensors[i].meta.ParseFromArray(pos, sz));
pos += sz;
sz = read_int32(&pos);
tensors[i].buf = pos;
tensors[i].size = sz;
pos += sz;
}
return ret;
}
private:
......
......@@ -19,7 +19,8 @@ REGISTER_OP("ZMQRecv")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive and return a serialized list of TensorProto from a ZMQ socket.
Receive a serialized list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format.
)doc");
......@@ -27,7 +28,7 @@ class ZMQRecvOp: public OpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK(conn_.get() == nullptr);
CHECK_EQ(conn_.get(), nullptr);
string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
......@@ -35,27 +36,34 @@ class ZMQRecvOp: public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
//GuardedTimer tm("Compute");
int start, stop;
TF_CHECK_OK(this->OutputRange("output", &start, &stop));
//cout << "COMPUTE" << endl;
auto protos = conn_->recv_tensor_list();
RecvTensorList tlist;
conn_->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors;
OpOutputList outputs;
OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
CHECK(protos.size() == num_components());
CHECK(tensors.size() == num_components());
for (int i = start; i < stop; ++i) {
Tensor output;
Tensor* output = nullptr;
int j = i - start;
OP_REQUIRES_OK(ctx, ctx->device()->MakeTensorFromProto(
protos[j], ctx->output_alloc_attr(i), &output));
auto recv_dtype = tensors[j].meta.dtype();
OP_REQUIRES(
ctx, component_types_[j] == output.dtype(),
ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (",
DataTypeString(output.dtype()), ") and dtype (",
DataTypeString(recv_dtype), ") and dtype (",
DataTypeString(component_types_[j]), ")"));
outputs.set(j, output);
TensorShape shape{tensors[j].meta.tensor_shape()};
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
outputs.set(j, *output);
}
}
private:
......
......@@ -78,14 +78,22 @@ def dump_tensor_protos(protos):
protos (list): list of :class:`TensorProto` instance
Notes:
The format is: <#protos(int32)>|<size 1>|<serialized proto 1>|<size 2><serialized proto 2>| ...
The format is:
[#tensors(int32)]
(tensor1)[size of meta proto][serialized meta proto][size of buffer][buffer]
(tensor2)...
"""
s = struct.pack('=i', len(protos))
for p in protos:
tensor_content = p.tensor_content
p.tensor_content = b'xxx' # clear content
buf = p.SerializeToString()
s += struct.pack('=i', len(buf)) # won't send stuff over 2G
s += struct.pack('=i', len(buf))
s += buf
s += struct.pack('=i', len(tensor_content)) # won't send stuff over 2G
s += tensor_content
return s
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment