Commit 99ddd038 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] use resource for ZMQ connection

parent 0594a9ad
......@@ -12,7 +12,7 @@ import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf # noqa
from tensorpack.user_ops.zmq_recv import ( # noqa
zmq_recv, dumps_zmq_op)
ZMQRecv, dumps_zmq_op)
from tensorpack.utils.concurrency import ( # noqa
start_proc_mask_signal,
ensure_proc_terminate)
......@@ -24,7 +24,7 @@ ENDPOINT = 'ipc://test-pipe'
def send(iterable, delay=0):
ctx = zmq.Context()
sok = ctx.socket(zmq.PUSH)
sok.bind(ENDPOINT)
sok.connect(ENDPOINT)
for dp in iterable:
if delay > 0:
......@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
recv = zmq_recv(ENDPOINT, [tf.float32, tf.uint8])
recv = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8]).recv()
print(recv)
for truth in DATA:
......@@ -87,8 +87,9 @@ if __name__ == '__main__':
start_proc_mask_signal(p)
sess = tf.Session()
recv1 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv2 = zmq_recv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
zmqsock = ZMQRecv(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
recv1 = zmqsock.recv()
recv2 = zmqsock.recv()
print(recv1, recv2)
for i in range(args.num // 2):
......
......@@ -5,8 +5,10 @@
#include <string>
#include <iostream>
#include <thread>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
......@@ -39,14 +41,16 @@ struct RecvTensorList {
tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
};
class ZMQConnection {
class ZMQConnection : public tensorflow::ResourceBase {
public:
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.connect(endpoint.c_str());
sock_.bind(endpoint.c_str());
}
std::string DebugString() override { return ""; }
void recv_tensor_list(RecvTensorList* tlist) {
{
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
......
......@@ -13,12 +13,14 @@ from tensorflow.core.framework import types_pb2 as DataType
from .common import compile, get_ext_suffix
__all__ = ['zmq_recv', 'dumps_zmq_op',
__all__ = ['dumps_zmq_op', 'ZMQRecv',
'dump_tensor_protos', 'to_tensor_proto']
def build():
global zmq_recv
_zmq_recv_mod = None
def try_build():
file_dir = os.path.dirname(os.path.abspath(__file__))
basename = 'zmq_recv_op' + get_ext_suffix()
so_file = os.path.join(file_dir, basename)
......@@ -27,11 +29,33 @@ def build():
if ret != 0:
raise RuntimeError("tensorpack user_ops compilation failed!")
recv_mod = tf.load_op_library(so_file)
zmq_recv = recv_mod.zmq_recv
global _zmq_recv_mod
_zmq_recv_mod = tf.load_op_library(so_file)
try_build()
class ZMQRecv(object):
def __init__(self, end_point, types, hwm=None, name=None):
self._types = types
if name is None:
self._name = (tf.get_default_graph()
.unique_name(self.__class__.__name__))
else:
self._name = name
self._zmq_handle = _zmq_recv_mod.zmq_connection(
end_point, hwm, shared_name=self._name)
@property
def name(self):
return self._name
build()
def recv(self):
return _zmq_recv_mod.zmq_recv(
self._zmq_handle, self._types)
_DTYPE_DICT = {
......
......@@ -3,87 +3,120 @@
#include <string>
#include <memory>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/resource_op_kernel.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/common_shape_fns.h>
#include "zmq_conn.h"
using namespace std;
using namespace tensorflow;
REGISTER_OP("ZMQRecv")
.Output("output: types")
.Attr("end_point: string")
.Attr("types: list(type) >= 1")
.Attr("hwm: int >= 1 = 10")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors by connecting to a ZMQ socket and pull from it.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
namespace tensorpack {
// An op to create zmq connection as a resource.
// Use ResourceOpKernel to ensure singleton construction.
class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
public:
explicit ZMQConnectionHandleOp(OpKernelConstruction* ctx)
: ResourceOpKernel<ZMQConnection>(ctx) {}
private:
Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const NodeDef& ndef = def();
string end_point;
int hwm;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &end_point));
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &hwm));
*ret = new ZMQConnection(end_point, ZMQ_PULL, hwm);
return Status::OK();
}
// TODO verify
};
namespace tensorpack {
class ZMQRecvOp: public AsyncOpKernel {
public:
explicit ZMQRecvOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
CHECK_EQ(conn_.get(), nullptr);
string endpoint;
OP_REQUIRES_OK(context, context->GetAttr("end_point", &endpoint));
int hwm;
OP_REQUIRES_OK(context, context->GetAttr("hwm", &hwm));
// will get called only at the first sess.run call
conn_.reset(new ZMQConnection(endpoint, ZMQ_PULL, hwm));
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//GuardedTimer tm("Compute");
int start, stop;
OP_REQUIRES_OK_ASYNC(ctx, this->OutputRange("output", &start, &stop), done);
ZMQConnection* conn = nullptr;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);
RecvTensorList tlist;
conn_->recv_tensor_list(&tlist);
conn->recv_tensor_list(&tlist);
auto& tensors = tlist.tensors;
OpOutputList outputs;
OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", &outputs), done);
CHECK(tensors.size() == num_components());
for (int i = start; i < stop; ++i) {
for (int i = 0; i < tensors.size(); ++i) {
Tensor* output = nullptr;
int j = i - start;
auto recv_dtype = tensors[j].dtype;
auto recv_dtype = tensors[i].dtype;
OP_REQUIRES_ASYNC(
ctx, component_types_[j] == recv_dtype,
errors::InvalidArgument("Type mismatch at index ", std::to_string(j),
ctx, component_types_[i] == recv_dtype,
errors::InvalidArgument("Type mismatch at index ", std::to_string(i),
" between received tensor (", DataTypeString(recv_dtype),
") and dtype (", DataTypeString(component_types_[j]), ")"),
") and dtype (", DataTypeString(component_types_[i]), ")"),
done);
TensorShape& shape = tensors[j].shape;
TensorShape& shape = tensors[i].shape;
OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);
// reinterpret cast and then memcpy
auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
memcpy(ptr, tensors[j].buf, tensors[j].buf_size);
outputs.set(j, *output);
memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
ctx->set_output(i, *output);
}
done();
}
private:
DataTypeVector component_types_;
unique_ptr<ZMQConnection> conn_;
size_t num_components() const { return component_types_.size(); }
};
REGISTER_KERNEL_BUILDER(Name("ZMQRecv").Device(DEVICE_CPU), ZMQRecvOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);
} // namespace tensorpack
REGISTER_OP("ZMQRecv")
.Input("handle: resource")
.Output("output: types")
.Attr("types: list(type) >= 1")
.SetShapeFn(shape_inference::UnknownShape)
.SetIsStateful()
.Doc(R"doc(
Receive a list of Tensors from a ZMQ connection handle.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");
REGISTER_OP("ZMQConnection")
.Output("handle: resource")
.Attr("end_point: string")
.Attr("hwm: int >= 1 = 10")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc");
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment