Commit 65c8b239 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] use AsyncOpKernel; better tests; use mutex. (#362)

parent a0d60a64
......@@ -3,10 +3,11 @@
# File: test-recv-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import sys
import os
import zmq
import argparse
import multiprocessing as mp
import time
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
......@@ -19,27 +20,46 @@ from tensorpack.utils.concurrency import ( # noqa
ENDPOINT = 'ipc://test-pipe'
if __name__ == '__main__':
try:
num = int(sys.argv[1])
except (ValueError, IndexError):
num = 10
DATA = []
def send(iterable, delay=0):
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.bind(ENDPOINT)
for dp in iterable:
if delay > 0:
time.sleep(delay)
print("Sending data to socket..")
sok.send(dumps_zmq_op(dp))
time.sleep(999)
def random_array(num):
ret = []
for k in range(num):
arr1 = np.random.rand(k + 10, k + 10).astype('float32')
arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
DATA.append([arr1, arr2])
ret.append([arr1, arr2])
return ret
def send():
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.connect(ENDPOINT)
for dp in DATA:
sok.send(dumps_zmq_op(dp))
def hash_dp(dp):
return sum([k.sum() for k in dp])
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='basic',
choices=['basic', 'tworecv'])
parser.add_argument('-n', '--num', type=int, default=10)
args = parser.parse_args()
if args.task == 'basic':
DATA = random_array(args.num)
p = mp.Process(target=send, args=(DATA,))
ensure_proc_terminate(p)
start_proc_mask_signal(p)
def recv():
sess = tf.Session()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
print(recv)
......@@ -49,8 +69,23 @@ if __name__ == '__main__':
assert (arr[0] == truth[0]).all()
assert (arr[1] == truth[1]).all()
p = mp.Process(target=send)
p.join()
if args.task == 'tworecv':
DATA = random_array(args.num)
hashes = [hash_dp(dp) for dp in DATA]
print(hashes)
p = mp.Process(target=send, args=(DATA, 0.00))
ensure_proc_terminate(p)
start_proc_mask_signal(p)
recv()
p.join()
sess = tf.Session()
recv1 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv2 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
print(recv1, recv2)
for i in range(args.num // 2):
res1, res2 = sess.run([recv1, recv2])
h1, h2 = hash_dp(res1), hash_dp(res2)
print("Recv ", i, h1, h2)
assert h1 in hashes and h2 in hashes
......@@ -7,6 +7,7 @@
#include <iostream>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
namespace {
......@@ -17,6 +18,8 @@ inline int read_int32(char** p) {
}
}
namespace tensorpack {
struct RecvTensorList {
zmq::message_t message;
......@@ -35,13 +38,19 @@ class ZMQConnection {
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str());
sock_.connect(endpoint.c_str());
}
void recv_tensor_list(RecvTensorList* tlist) {
// TODO critical section
bool succ = sock_.recv(&tlist->message);
{
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// zmq socket is not thread safe
tensorflow::mutex_lock lk(mu_);
bool succ = sock_.recv(&tlist->message); // TODO this may throw
// possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// succ=false only if EAGAIN
CHECK(succ); // no EAGAIN, because we are blocking
}
char* pos = reinterpret_cast<char*>(tlist->message.data());
......@@ -67,6 +76,10 @@ class ZMQConnection {
}
private:
tensorflow::mutex mu_;
zmq::context_t ctx_;
zmq::socket_t sock_;
};
} // namespace tensorpack
......@@ -16,18 +16,21 @@ REGISTER_OP("ZMQRecv")
.Output("output: types")
.Attr("end_point: string")
.Attr("types: list(type) >= 1")
.Attr("hwm: int >= 1 = 100")
.Attr("hwm: int >= 1 = 10")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors from a ZMQ socket.
Receive a list of Tensors by connecting to a ZMQ socket and pull from it.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
class ZMQRecvOp: public OpKernel {
namespace tensorpack {
class ZMQRecvOp: public AsyncOpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : OpKernel(context) {
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK_EQ(conn_.get(), nullptr);
......@@ -39,36 +42,37 @@ class ZMQRecvOp: public OpKernel {
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
}
void Compute(OpKernelContext* ctx) override {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//GuardedTimer tm("Compute");
int start, stop;
TF_CHECK_OK(this->OutputRange("output", &start, &stop));
OP_REQUIRES_OK_ASYNC(ctx, this->OutputRange("output", &start, &stop), done);
RecvTensorList tlist;
conn_->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors;
OpOutputList outputs;
OP_REQUIRES_OK(ctx, ctx->output_list("output", &outputs));
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", &outputs), done);
CHECK(tensors.size() == num_components());
for (int i = start; i < stop; ++i) {
Tensor* output = nullptr;
int j = i - start;
auto recv_dtype = tensors[j].dtype;
OP_REQUIRES(
OP_REQUIRES_ASYNC(
ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch between parsed tensor (",
DataTypeString(recv_dtype), ") and dtype (",
DataTypeString(component_types_[j]), ")"));
DataTypeString(component_types_[j]), ")"), done);
TensorShape& shape = tensors[j].shape;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output));
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()});
memcpy(ptr.data(), tensors[j].buf, tensors[j].size);
outputs.set(j, *output);
}
done();
}
private:
DataTypeVector component_types_;
......@@ -77,4 +81,8 @@ class ZMQRecvOp: public OpKernel {
size_t num_components() const { return component_types_.size(); }
};
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
} // namespace tensorpack
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment