Commit 65c8b239 authored by Yuxin Wu

[ZMQ] use AsyncOpKernel; better tests; use mutex. (#362)

parent a0d60a64
@@ -3,10 +3,11 @@
 # File: test-recv-op.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-import sys
 import os
 import zmq
+import argparse
 import multiprocessing as mp
+import time
 import numpy as np
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 import tensorflow as tf  # noqa
@@ -19,27 +20,46 @@ from tensorpack.utils.concurrency import (  # noqa
 ENDPOINT = 'ipc://test-pipe'
 
-if __name__ == '__main__':
-    try:
-        num = int(sys.argv[1])
-    except (ValueError, IndexError):
-        num = 10
 
-    DATA = []
+def send(iterable, delay=0):
+    ctx = zmq.Context()
+    sok = ctx.socket(zmq.PUSH)
+    sok.bind(ENDPOINT)
+
+    for dp in iterable:
+        if delay > 0:
+            time.sleep(delay)
+        print("Sending data to socket..")
+        sok.send(dumps_zmq_op(dp))
+
+    time.sleep(999)
+
+
+def random_array(num):
+    ret = []
     for k in range(num):
         arr1 = np.random.rand(k + 10, k + 10).astype('float32')
         arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
-        DATA.append([arr1, arr2])
+        ret.append([arr1, arr2])
+    return ret
 
-    def send():
-        ctx = zmq.Context()
-        sok = ctx.socket(zmq.PUSH)
-        sok.connect(ENDPOINT)
 
-        for dp in DATA:
-            sok.send(dumps_zmq_op(dp))
+def hash_dp(dp):
+    return sum([k.sum() for k in dp])
 
-    def recv():
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', default='basic',
+                        choices=['basic', 'tworecv'])
+    parser.add_argument('-n', '--num', type=int, default=10)
+    args = parser.parse_args()
+
+    if args.task == 'basic':
+        DATA = random_array(args.num)
+        p = mp.Process(target=send, args=(DATA,))
+        ensure_proc_terminate(p)
+        start_proc_mask_signal(p)
+
         sess = tf.Session()
         recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
         print(recv)
@@ -49,8 +69,23 @@ if __name__ == '__main__':
         assert (arr[0] == truth[0]).all()
         assert (arr[1] == truth[1]).all()
-
-    p = mp.Process(target=send)
-    ensure_proc_terminate(p)
-    start_proc_mask_signal(p)
-    recv()
-    p.join()
+        p.join()
+
+    if args.task == 'tworecv':
+        DATA = random_array(args.num)
+        hashes = [hash_dp(dp) for dp in DATA]
+        print(hashes)
+        p = mp.Process(target=send, args=(DATA, 0.00))
+        ensure_proc_terminate(p)
+        start_proc_mask_signal(p)
+
+        sess = tf.Session()
+        recv1 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
+        recv2 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
+        print(recv1, recv2)
+        for i in range(args.num // 2):
+            res1, res2 = sess.run([recv1, recv2])
+            h1, h2 = hash_dp(res1), hash_dp(res2)
+            print("Recv ", i, h1, h2)
+            assert h1 in hashes and h2 in hashes
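
A note on the topology this test now exercises: after this commit the data source binds the PUSH socket and every ZMQRecv op connects a PULL socket, so ZMQ round-robins each message to exactly one of the connected pullers, which is what the tworecv task checks. Below is a minimal standalone sketch of that pairing in plain pyzmq (no TensorFlow; the endpoint name, message bytes, and the 0.5s settling delay are made up for illustration):

import multiprocessing as mp
import time
import zmq

ENDPOINT = 'ipc://test-pipe-sketch'    # made-up endpoint for this sketch


def sender(n):
    ctx = zmq.Context()
    sok = ctx.socket(zmq.PUSH)
    sok.bind(ENDPOINT)                  # the data source binds once...
    time.sleep(0.5)                     # crude barrier: let both pullers connect
                                        # before round-robin distribution starts
    for i in range(n):
        sok.send(b'msg-%d' % i)


def receiver(name, n):
    ctx = zmq.Context()
    sok = ctx.socket(zmq.PULL)
    sok.setsockopt(zmq.RCVHWM, 1)       # mirrors hwm=1 in the tworecv test
    sok.connect(ENDPOINT)               # ...and any number of pullers connect
    for _ in range(n):
        print(name, sok.recv())


if __name__ == '__main__':
    ps = [mp.Process(target=receiver, args=('recv%d' % i, 5)) for i in range(2)]
    for p in ps:
        p.start()
    sender(10)                          # PUSH round-robins across the two pullers
    for p in ps:
        p.join()

With two connected PULL sockets each message is delivered to one puller only, which is why every hash seen by recv1 or recv2 in the test must appear in the original hash list.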
@@ -7,6 +7,7 @@
 #include <iostream>
 #include <tensorflow/core/framework/tensor_shape.h>
 #include <tensorflow/core/lib/gtl/inlined_vector.h>
+#include <tensorflow/core/platform/mutex.h>
 #include "zmq.hpp"
 
 namespace {
@@ -17,6 +18,8 @@ inline int read_int32(char** p) {
 }
 }
 
+namespace tensorpack {
+
 struct RecvTensorList {
   zmq::message_t message;
@@ -35,13 +38,19 @@ class ZMQConnection {
   ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
     ctx_(1), sock_(ctx_, zmq_socket_type) {
     sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
-    sock_.bind(endpoint.c_str());
+    sock_.connect(endpoint.c_str());
   }
 
   void recv_tensor_list(RecvTensorList* tlist) {
-    // TODO critical section
-    bool succ = sock_.recv(&tlist->message);
-    CHECK(succ);    // no EAGAIN, because we are blocking
+    {
+      // https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
+      // zmq socket is not thread safe
+      tensorflow::mutex_lock lk(mu_);
+      bool succ = sock_.recv(&tlist->message);  // TODO this may throw
+      // possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
+      // succ=false only if EAGAIN
+      CHECK(succ);    // no EAGAIN, because we are blocking
+    }
 
     char* pos = reinterpret_cast<char*>(tlist->message.data());
@@ -67,6 +76,10 @@ class ZMQConnection {
   }
 
  private:
+  tensorflow::mutex mu_;
   zmq::context_t ctx_;
   zmq::socket_t sock_;
 };
+
+}  // namespace tensorpack
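
Why the mutex: one ZMQConnection can be hit by several executor threads at once, and ZMQ sockets are not thread-safe, so the patch serializes only the blocking recv and parses the message outside the critical section. A rough Python analogue of the same discipline, with pyzmq and threading.Lock standing in for zmq::socket_t and tensorflow::mutex (class and method names are illustrative, not part of the patch):

import threading
import zmq


class LockedPullSocket:
    """One PULL socket shared by many reader threads (sketch)."""

    def __init__(self, endpoint, hwm=10):
        self._ctx = zmq.Context()
        self._sock = self._ctx.socket(zmq.PULL)
        self._sock.setsockopt(zmq.RCVHWM, hwm)
        self._sock.connect(endpoint)
        self._lock = threading.Lock()   # plays the role of mu_ above

    def recv_msg(self):
        # Hold the lock only across the blocking recv(), like the scoped
        # tensorflow::mutex_lock block; callers parse the returned bytes
        # outside the critical section.
        with self._lock:
            return self._sock.recv()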
@@ -16,18 +16,21 @@ REGISTER_OP("ZMQRecv")
     .Output("output: types")
     .Attr("end_point: string")
     .Attr("types: list(type) >= 1")
-    .Attr("hwm: int >= 1 = 100")
+    .Attr("hwm: int >= 1 = 10")
     .SetShapeFn(shape_inference::UnknownShape)
     .SetIsStateful()
     .Doc(R"doc(
-Receive a list of Tensors from a ZMQ socket.
+Receive a list of Tensors by connecting to a ZMQ socket and pulling from it.
 The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
 )doc");
 
+namespace tensorpack {
+
-class ZMQRecvOp: public OpKernel {
+class ZMQRecvOp: public AsyncOpKernel {
  public:
-  explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
     CHECK_EQ(conn_.get(), nullptr);
@@ -39,36 +42,37 @@ class ZMQRecvOp: public OpKernel {
     conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
   }
 
-  void Compute(OpKernelContext* ctx) override {
+  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
     //GuardedTimer tm("Compute");
     int start, stop;
-    TF_CHECK_OK(this->OutputRange("output", &start, &stop));
+    OP_REQUIRES_OK_ASYNC(ctx, this->OutputRange("output", &start, &stop), done);
 
     RecvTensorList tlist;
     conn_->recv_tensor_list(&tlist);
     auto& tensors = tlist.tensors;
 
     OpOutputList outputs;
-    OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
+    OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", &outputs), done);
     CHECK(tensors.size() == num_components());
 
     for (int i = start; i < stop; ++i) {
       Tensor* output = nullptr;
       int j = i - start;
       auto recv_dtype = tensors[j].dtype;
-      OP_REQUIRES(
+      OP_REQUIRES_ASYNC(
           ctx, component_types_[j] == recv_dtype,
           errors::InvalidArgument("Type mismatch between parsed tensor (",
                                   DataTypeString(recv_dtype), ") and dtype (",
-                                  DataTypeString(component_types_[j]), ")"));
+                                  DataTypeString(component_types_[j]), ")"), done);
 
       TensorShape& shape = tensors[j].shape;
-      OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
+      OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
       auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
       memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
       outputs.set(j, *output);
     }
+    done();
   }
 
  private:
   DataTypeVector component_types_;
@@ -77,4 +81,8 @@ class ZMQRecvOp: public OpKernel {
   size_t num_components() const { return component_types_.size(); }
 };
 
 REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
+
+}  // namespace tensorpack
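
Why AsyncOpKernel: a synchronous Compute that blocks on sock_.recv() pins one of TensorFlow's inter-op threads for the entire wait, and AsyncOpKernel (with the done callback signalling completion) is the documented way to write kernels that may block. A toy, TF-free sketch of the starvation a blocking synchronous kernel can cause; the 2-thread pool stands in for the inter-op pool and all names are made up:

import queue
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=2)   # stand-in for TF's inter-op pool
inbox = queue.Queue()


def blocking_kernel(name):
    data = inbox.get()                     # blocks, like sock_.recv()
    print(name, 'received', data)


# Two blocking "recv kernels" occupy both workers...
pool.submit(blocking_kernel, 'recv1')
pool.submit(blocking_kernel, 'recv2')

# ...so an unrelated cheap task cannot start until data arrives.
t0 = time.time()
other = pool.submit(lambda: print('other op ran after %.1fs' % (time.time() - t0)))
time.sleep(1.0)                            # nothing arrives for a second
inbox.put('a')
inbox.put('b')                             # data frees both workers
other.result()
pool.shutdown()

Running this prints that the third task only ran after roughly a second, once the two blocking tasks released their threads; switching to ComputeAsync lets the runtime account for such long-waiting kernels instead of silently losing pool threads to them.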