Commit 99ddd038 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] use resource for ZMQ connection

parent 0594a9ad
...@@ -12,7 +12,7 @@ import numpy as np ...@@ -12,7 +12,7 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa from tensorpack.user_ops.zmq_recv import ( # noqa
zmq_recv, dumps_zmq_op) ZMQRecv, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal, start_proc_mask_signal,
ensure_proc_terminate) ensure_proc_terminate)
...@@ -24,7 +24,7 @@ ENDPOINT = 'ipc://test-pipe' ...@@ -24,7 +24,7 @@ ENDPOINT = 'ipc://test-pipe'
def send(iterable, delay=0): def send(iterable, delay=0):
ctx = zmq.Context() ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH) sok = ctx.socket(zmq.PUSH)
sok.bind(ENDPOINT) sok.connect(ENDPOINT)
for dp in iterable: for dp in iterable:
if delay > 0: if delay > 0:
...@@ -68,7 +68,7 @@ if __name__ == '__main__': ...@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8]) recv = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8]).recv()
print(recv) print(recv)
for truth in DATA: for truth in DATA:
...@@ -87,8 +87,9 @@ if __name__ == '__main__': ...@@ -87,8 +87,9 @@ if __name__ == '__main__':
start_proc_mask_signal(p) start_proc_mask_signal(p)
sess = tf.Session() sess = tf.Session()
recv1 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1) zmqsock = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv2 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1) recv1 = zmqsock.recv()
recv2 = zmqsock.recv()
print(recv1, recv2) print(recv1, recv2)
for i in range(args.num // 2): for i in range(args.num // 2):
......
...@@ -5,8 +5,10 @@ ...@@ -5,8 +5,10 @@
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <thread>
#include <tensorflow/core/framework/tensor_shape.h> #include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h> #include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/platform/mutex.h> #include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp" #include "zmq.hpp"
...@@ -39,14 +41,16 @@ struct RecvTensorList { ...@@ -39,14 +41,16 @@ struct RecvTensorList {
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors; tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
}; };
class ZMQConnection { class ZMQConnection : public tensorflow::ResourceBase {
public: public:
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm): ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) { ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm); sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.connect(endpoint.c_str()); sock_.bind(endpoint.c_str());
} }
std::string DebugString() override { return ""; }
void recv_tensor_list(RecvTensorList* tlist) { void recv_tensor_list(RecvTensorList* tlist) {
{ {
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels // https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
......
...@@ -13,12 +13,14 @@ from tensorflow.core.framework import types_pb2 as DataType ...@@ -13,12 +13,14 @@ from tensorflow.core.framework import types_pb2 as DataType
from .common import compile, get_ext_suffix from .common import compile, get_ext_suffix
__all__ = ['zmq_recv', 'dumps_zmq_op', __all__ = ['dumps_zmq_op', 'ZMQRecv',
'dump_tensor_protos', 'to_tensor_proto'] 'dump_tensor_protos', 'to_tensor_proto']
def build(): _zmq_recv_mod = None
global zmq_recv
def try_build():
file_dir = os.path.dirname(os.path.abspath(__file__)) file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_recv_op' + get_ext_suffix() basename = 'zmq_recv_op' + get_ext_suffix()
so_file = os.path.join(file_dir, basename) so_file = os.path.join(file_dir, basename)
...@@ -27,11 +29,33 @@ def build(): ...@@ -27,11 +29,33 @@ def build():
if ret != 0: if ret != 0:
raise RuntimeError("tensorpack user_ops compilation failed!") raise RuntimeError("tensorpack user_ops compilation failed!")
recv_mod = tf.load_op_library(so_file) global _zmq_recv_mod
zmq_recv = recv_mod.zmq_recv _zmq_recv_mod = tf.load_op_library(so_file)
try_build()
class ZMQRecv(object):
def __init__(self, end_point, types, hwm=None, name=None):
self._types = types
if name is None:
self._name = (tf.get_default_graph()
.unique_name(self.__class__.__name__))
else:
self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection(
end_point, hwm, shared_name=self._name)
@property
def name(self):
return self._name
build() def recv(self):
return _zmq_recv_mod.zmq_recv(
self._zmq_handle, self._types)
_DTYPE_DICT = { _DTYPE_DICT = {
......
...@@ -3,87 +3,120 @@ ...@@ -3,87 +3,120 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include <tensorflow/core/framework/op.h>
#include "tensorflow/core/framework/common_shape_fns.h" #include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/resource_op_kernel.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/common_shape_fns.h>
#include "zmq_conn.h" #include "zmq_conn.h"
using namespace std; using namespace std;
using namespace tensorflow; using namespace tensorflow;
REGISTER_OP("ZMQRecv") namespace tensorpack {
.Output("output: types")
.Attr("end_point: string")
.Attr("types: list(type) >= 1")
.Attr("hwm: int >= 1 = 10")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors by connecting to a ZMQ socket and pull from it.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
// An op to create zmq connection as a resource.
// Use ResourceOpKernel to ensure singleton construction.
class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
public:
explicit ZMQConnectionHandleOp(OpKernelConstruction* ctx)
: ResourceOpKernel<ZMQConnection>(ctx) {}
private:
Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const NodeDef& ndef = def();
string end_point;
int hwm;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &end_point));
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &hwm));
*ret = new ZMQConnection(end_point, ZMQ_PULL, hwm);
return Status::OK();
}
// TODO verify
};
namespace tensorpack {
class ZMQRecvOp: public AsyncOpKernel { class ZMQRecvOp: public AsyncOpKernel {
public: public:
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) { explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_)); OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK_EQ(conn_.get(), nullptr);
string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
int hwm;
OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
// will get called only at the first sess.run call
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
} }
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//GuardedTimer tm("Compute"); //GuardedTimer tm("Compute");
int start, stop; ZMQConnection* conn = nullptr;
OP_REQUIRES_OK_ASYNC(ctx, this->OutputRange("output", &start, &stop), done); OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);
RecvTensorList tlist; RecvTensorList tlist;
conn_->recv_tensor_list(&tlist); conn->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors; auto& tensors = tlist.tensors;
OpOutputList outputs;
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", &outputs), done);
CHECK(tensors.size() == num_components()); CHECK(tensors.size() == num_components());
for (int i = start; i < stop; ++i) { for (int i = 0; i < tensors.size(); ++i) {
Tensor* output = nullptr; Tensor* output = nullptr;
int j = i - start; auto recv_dtype = tensors[i].dtype;
auto recv_dtype = tensors[j].dtype;
OP_REQUIRES_ASYNC( OP_REQUIRES_ASYNC(
ctx, component_types_[j] == recv_dtype, ctx, component_types_[i] == recv_dtype,
errors::InvalidArgument("Type mismatch at index ", std::to_string(j), errors::InvalidArgument("Type mismatch at index ", std::to_string(i),
" between received tensor (", DataTypeString(recv_dtype), " between received tensor (", DataTypeString(recv_dtype),
") and dtype (", DataTypeString(component_types_[j]), ")"), ") and dtype (", DataTypeString(component_types_[i]), ")"),
done); done);
TensorShape& shape = tensors[j].shape; TensorShape& shape = tensors[i].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done); OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
// reinterpret cast and then memcpy
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data(); auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
memcpy(ptr, tensors[j].buf, tensors[j].buf_size); memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
outputs.set(j, *output); ctx->set_output(i, *output);
} }
done(); done();
} }
private: private:
DataTypeVector component_types_; DataTypeVector component_types_;
unique_ptr<ZMQConnection> conn_;
size_t num_components() const { return component_types_.size(); } size_t num_components() const { return component_types_.size(); }
}; };
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp); REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);
} // namespace tensorpack } // namespace tensorpack
REGISTER_OP("ZMQRecv")
.Input("handle: resource")
.Output("output: types")
.Attr("types: list(type) >= 1")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors from a ZMQ connection handle.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
REGISTER_OP("ZMQConnection")
.Output("handle: resource")
.Attr("end_point: string")
.Attr("hwm: int >= 1 = 10")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment