Commit d1ba5969 authored by Yuxin Wu

[ZMQ] more options for zmq socket. (#362)

parent 99ddd038
......@@ -21,18 +21,25 @@ else:
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
It will __connect__ to this addr,
serialize and send each datapoint to this addr with a PUSH socket.
This function never returns unless an error is encountered.
Args:
df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr.
addr: a ZMQ socket endpoint.
hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize` (i.e. msgpack).
An alternate format is 'zmq_op'.
"""
# format (str): The serialization format. ZMQ Op is still not publicly usable now
# Default format would use :mod:`tensorpack.utils.serialize`.
# dump_fn = dumps if format is None else dumps_for_tfop
dump_fn = dumps
assert format in [None, 'zmq_op']
if format is None:
dump_fn = dumps
else:
from ..user_ops.zmq_recv import dumps_zmq_op
dump_fn = dumps_zmq_op
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm)
......
......@@ -6,9 +6,11 @@
#include <string>
#include <iostream>
#include <thread>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/lib/strings/strcat.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
......@@ -20,7 +22,7 @@ inline int read_int32(char** p) {
}
// Reads an 8-byte integer located at *p and advances *p by 8 bytes.
// The bytes are reinterpreted in place, so the value is taken in the byte
// order of the machine running this code.
// NOTE(review): assumes *p is suitably aligned for an 8-byte load — confirm
// against the serializer that produces the buffer.
inline tensorflow::int64 read_int64(char** p) {
  // Cast to the exact type that is returned (not `long long`), so the pointee
  // type and the return type cannot drift apart on any platform.
  auto pi = reinterpret_cast<const tensorflow::int64*>(*p);
  *p += 8;
  return *pi;
}
......@@ -28,6 +30,17 @@ inline tensorflow::int64 read_int64(char** p) {
namespace tensorpack {
// Plain description of how a ZMQ socket should be opened.
// Member order is significant for aggregate initialization:
// end_point, socket_type, hwm, bind.
struct ZMQSocketDef {
  // Endpoint string that the connection binds or connects to.
  std::string end_point;
  // ZMQ socket type constant, e.g. ZMQ_PULL.
  int socket_type;
  // High-water mark; ZMQConnection applies it through ZMQ_RCVHWM.
  int hwm;
  // true: bind to end_point; false: connect to end_point.
  bool bind;

  // Short human-readable summary: the endpoint and the high-water mark.
  std::string DebugString() const {
    const std::string hwm_text = std::to_string(hwm);
    return tensorflow::strings::StrCat("EndPoint=", end_point, ", hwm=", hwm_text);
  }
};
struct RecvTensorList {
zmq::message_t message;
......@@ -43,13 +56,20 @@ struct RecvTensorList {
class ZMQConnection : public tensorflow::ResourceBase {
public:
ZMQConnection(std::string endpoint, int zmq_socket_type, int hwm):
ctx_(1), sock_(ctx_, zmq_socket_type) {
sock_.setsockopt(ZMQ_RCVHWM, &hwm, sizeof hwm);
sock_.bind(endpoint.c_str());
explicit ZMQConnection(const ZMQSocketDef& def):
def_{def}, ctx_{1}, sock_{ctx_, def.socket_type} {
int linger = 0;
sock_.setsockopt(ZMQ_LINGER, &linger , sizeof linger);
sock_.setsockopt(ZMQ_RCVHWM, &def.hwm , sizeof def.hwm);
if (def.bind) {
sock_.bind(def.end_point.c_str());
} else {
sock_.connect(def.end_point.c_str());
}
}
// Describes this connection by forwarding to the stored socket definition's
// DebugString(). Defined exactly once: a second definition of the same
// override in the class body is a redefinition error.
std::string DebugString() override { return def_.DebugString(); }
void recv_tensor_list(RecvTensorList* tlist) {
{
......@@ -86,7 +106,10 @@ class ZMQConnection : public tensorflow::ResourceBase {
}
}
// Read-only access to the socket definition stored in def_.
const ZMQSocketDef& get_socket_def() const { return def_; }
private:
ZMQSocketDef def_;
tensorflow::mutex mu_;
zmq::context_t ctx_;
zmq::socket_t sock_;
......
......@@ -27,15 +27,17 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
private:
// Builds the ZMQConnection resource from this node's attributes.
// Reads the "bind", "end_point" and "hwm" attributes into a ZMQSocketDef; the
// socket type is always ZMQ_PULL. Returns the first attribute-lookup error if
// any lookup fails; otherwise stores a newly allocated connection in *ret and
// returns Status::OK().
Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  const NodeDef& ndef = def();
  ZMQSocketDef sockdef;
  sockdef.socket_type = ZMQ_PULL;
  TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "bind", &sockdef.bind));
  TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &sockdef.end_point));
  TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &sockdef.hwm));
  // Construct exactly one connection. Allocating a second one and overwriting
  // *ret would leak the first and open the same endpoint twice.
  *ret = new ZMQConnection(sockdef);
  return Status::OK();
}
// TODO verify
// Can verify, but probably not necessary because python is not going to eval this op twice with
// the same shared name
};
......@@ -46,7 +48,6 @@ class ZMQRecvOp: public AsyncOpKernel {
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
//GuardedTimer tm("Compute");
ZMQConnection* conn = nullptr;
OP_REQUIRES_OK_ASYNC(
ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);
......@@ -105,6 +106,7 @@ REGISTER_OP("ZMQConnection")
.Output("handle: resource")
.Attr("end_point: string")
.Attr("hwm: int >= 1 = 10")
.Attr("bind: bool = true")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
......@@ -115,8 +117,7 @@ REGISTER_OP("ZMQConnection")
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
bind: If false, will connect to the endpoint rather than bind to it.
container: required for a resource op kernel.
shared_name: If non-empty, this connection will be shared under the given name across multiple sessions.
)doc");
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment