Commit d1ba5969 authored by Yuxin Wu's avatar Yuxin Wu

[ZMQ] more options for zmq socket. (#362)

parent 99ddd038
...@@ -21,18 +21,25 @@ else: ...@@ -21,18 +21,25 @@ else:
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None): def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
""" """
Run DataFlow and send data to a ZMQ socket addr. Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket. It will __connect__ to this addr,
serialize and send each datapoint to this addr with a PUSH socket.
This function never returns unless an error is encountered. This function never returns unless an error is encountered.
Args: Args:
df (DataFlow): Will infinitely loop over the DataFlow. df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr. addr: a ZMQ socket endpoint.
hwm (int): ZMQ high-water mark (buffer size) hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize` (i.e. msgpack).
An alternate format is 'zmq_op'.
""" """
# format (str): The serialization format. ZMQ Op is still not publicly usable now assert format in [None, 'zmq_op']
# Default format would use :mod:`tensorpack.utils.serialize`. if format is None:
# dump_fn = dumps if format is None else dumps_for_tfop
dump_fn = dumps dump_fn = dumps
else:
from ..user_ops.zmq_recv import dumps_zmq_op
dump_fn = dumps_zmq_op
ctx = zmq.Context() ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH) socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm) socket.set_hwm(hwm)
......
...@@ -6,9 +6,11 @@ ...@@ -6,9 +6,11 @@
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <thread> #include <thread>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/tensor_shape.h> #include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h> #include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/framework/resource_mgr.h> #include <tensorflow/core/lib/strings/strcat.h>
#include <tensorflow/core/platform/mutex.h> #include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp" #include "zmq.hpp"
...@@ -20,7 +22,7 @@ inline int read_int32(char** p) { ...@@ -20,7 +22,7 @@ inline int read_int32(char** p) {
} }
inline tensorflow::int64 read_int64(char** p) { inline tensorflow::int64 read_int64(char** p) {
auto pi = reinterpret_cast<const long long*>(*p); auto pi = reinterpret_cast<const tensorflow::int64*>(*p);
*p += 8; *p += 8;
return *pi; return *pi;
} }
...@@ -28,6 +30,17 @@ inline tensorflow::int64 read_int64(char** p) { ...@@ -28,6 +30,17 @@ inline tensorflow::int64 read_int64(char** p) {
namespace tensorpack { namespace tensorpack {
struct ZMQSocketDef {
std::string end_point;
int socket_type, // ZMQ_PULL
hwm;
bool bind; // bind or connect
std::string DebugString() const {
return tensorflow::strings::StrCat("EndPoint=", end_point, ", hwm=", std::to_string(hwm));
}
};
struct RecvTensorList { struct RecvTensorList {
zmq::message_t message; zmq::message_t message;
...@@ -43,13 +56,20 @@ struct RecvTensorList { ...@@ -43,13 +56,20 @@ struct RecvTensorList {
class ZMQConnection : public tensorflow::ResourceBase { class ZMQConnection : public tensorflow::ResourceBase {
public: public:
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm): explicit ZMQConnection(const ZMQSocketDef& def):
ctx_(1), sock_(ctx_, zmq_socket_type) { def_{def}, ctx_{1}, sock_{ctx_, def.socket_type} {
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm); int linger = 0;
sock_.bind(endpoint.c_str()); sock_.setsockopt(ZMQ_LINGER, &linger , sizeof linger);
sock_.setsockopt(ZMQ_RCVHWM, &def.hwm , sizeof def.hwm);
if (def.bind) {
sock_.bind(def.end_point.c_str());
} else {
sock_.connect(def.end_point.c_str());
}
} }
std::string DebugString() override { return ""; } std::string DebugString() override { return def_.DebugString(); }
void recv_tensor_list(RecvTensorList* tlist) { void recv_tensor_list(RecvTensorList* tlist) {
{ {
...@@ -86,7 +106,10 @@ class ZMQConnection : public tensorflow::ResourceBase { ...@@ -86,7 +106,10 @@ class ZMQConnection : public tensorflow::ResourceBase {
} }
} }
const ZMQSocketDef& get_socket_def() const { return def_; }
private: private:
ZMQSocketDef def_;
tensorflow::mutex mu_; tensorflow::mutex mu_;
zmq::context_t ctx_; zmq::context_t ctx_;
zmq::socket_t sock_; zmq::socket_t sock_;
......
...@@ -27,15 +27,17 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> { ...@@ -27,15 +27,17 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
private: private:
Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) { Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
const NodeDef& ndef = def(); const NodeDef& ndef = def();
string end_point; ZMQSocketDef sockdef;
int hwm; sockdef.socket_type = ZMQ_PULL;
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &end_point)); TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "bind", &sockdef.bind));
TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &hwm)); TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &sockdef.end_point));
*ret = new ZMQConnection(end_point, ZMQ_PULL, hwm); TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &sockdef.hwm));
*ret = new ZMQConnection(sockdef);
return Status::OK(); return Status::OK();
} }
// TODO verify // Can verify, but probably not necessary because python is not going to eval this op twice with
// the same shared name
}; };
...@@ -46,7 +48,6 @@ class ZMQRecvOp: public AsyncOpKernel { ...@@ -46,7 +48,6 @@ class ZMQRecvOp: public AsyncOpKernel {
} }
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//GuardedTimer tm("Compute");
ZMQConnection* conn = nullptr; ZMQConnection* conn = nullptr;
OP_REQUIRES_OK_ASYNC( OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done); ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);
...@@ -105,6 +106,7 @@ REGISTER_OP("ZMQConnection") ...@@ -105,6 +106,7 @@ REGISTER_OP("ZMQConnection")
.Output("handle: resource") .Output("handle: resource")
.Attr("end_point: string") .Attr("end_point: string")
.Attr("hwm: int >= 1 = 10") .Attr("hwm: int >= 1 = 10")
.Attr("bind: bool = true")
.Attr("container: string = ''") .Attr("container: string = ''")
.Attr("shared_name: string = ''") .Attr("shared_name: string = ''")
...@@ -115,8 +117,7 @@ REGISTER_OP("ZMQConnection") ...@@ -115,8 +117,7 @@ REGISTER_OP("ZMQConnection")
Opens a ZMQ PULL socket and returns a handle to it as a resource. Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point. end_point: the ZMQ end point.
hwm: ZMQ high-water mark. hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container. bind: If false, will connect to the endpoint rather than bind to it.
Otherwise, a default container is used. container: required for a resource op kernel.
shared_name: If non-empty, this queue will be shared under the given name shared_name: If non-empty, this connection will be shared under the given name across multiple sessions.
across multiple sessions.
)doc"); )doc");
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment