Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d1ba5969
Commit
d1ba5969
authored
Dec 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ZMQ] more options for zmq socket. (#362)
parent
99ddd038
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
24 deletions
+55
-24
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+13
-6
tensorpack/user_ops/zmq_conn.h
tensorpack/user_ops/zmq_conn.h
+30
-7
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+12
-11
No files found.
tensorpack/dataflow/remote.py
View file @
d1ba5969
...
@@ -21,18 +21,25 @@ else:
...
@@ -21,18 +21,25 @@ else:
def
send_dataflow_zmq
(
df
,
addr
,
hwm
=
50
,
print_interval
=
100
,
format
=
None
):
def
send_dataflow_zmq
(
df
,
addr
,
hwm
=
50
,
print_interval
=
100
,
format
=
None
):
"""
"""
Run DataFlow and send data to a ZMQ socket addr.
Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
It will __connect__ to this addr,
serialize and send each datapoint to this addr with a PUSH socket.
This function never returns unless an error is encountered.
This function never returns unless an error is encountered.
Args:
Args:
df (DataFlow): Will infinitely loop over the DataFlow.
df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket
addr
.
addr: a ZMQ socket
endpoint
.
hwm (int): ZMQ high-water mark (buffer size)
hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize` (i.e. msgpack).
An alternate format is 'zmq_op'.
"""
"""
# format (str): The serialization format. ZMQ Op is still not publicly usable now
assert
format
in
[
None
,
'zmq_op'
]
# Default format would use :mod:`tensorpack.utils.serialize`.
if
format
is
None
:
# dump_fn = dumps if format is None else dumps_for_tfop
dump_fn
=
dumps
dump_fn
=
dumps
else
:
from
..user_ops.zmq_recv
import
dumps_zmq_op
dump_fn
=
dumps_zmq_op
ctx
=
zmq
.
Context
()
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
.
set_hwm
(
hwm
)
socket
.
set_hwm
(
hwm
)
...
...
tensorpack/user_ops/zmq_conn.h
View file @
d1ba5969
...
@@ -6,9 +6,11 @@
...
@@ -6,9 +6,11 @@
#include <string>
#include <string>
#include <iostream>
#include <iostream>
#include <thread>
#include <thread>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/
framework/resource_mgr
.h>
#include <tensorflow/core/
lib/strings/strcat
.h>
#include <tensorflow/core/platform/mutex.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
#include "zmq.hpp"
...
@@ -20,7 +22,7 @@ inline int read_int32(char** p) {
...
@@ -20,7 +22,7 @@ inline int read_int32(char** p) {
}
}
inline
tensorflow
::
int64
read_int64
(
char
**
p
)
{
inline
tensorflow
::
int64
read_int64
(
char
**
p
)
{
auto
pi
=
reinterpret_cast
<
const
long
long
*>
(
*
p
);
auto
pi
=
reinterpret_cast
<
const
tensorflow
::
int64
*>
(
*
p
);
*
p
+=
8
;
*
p
+=
8
;
return
*
pi
;
return
*
pi
;
}
}
...
@@ -28,6 +30,17 @@ inline tensorflow::int64 read_int64(char** p) {
...
@@ -28,6 +30,17 @@ inline tensorflow::int64 read_int64(char** p) {
namespace
tensorpack
{
namespace
tensorpack
{
struct
ZMQSocketDef
{
std
::
string
end_point
;
int
socket_type
,
// ZMQ_PULL
hwm
;
bool
bind
;
// bind or connect
std
::
string
DebugString
()
const
{
return
tensorflow
::
strings
::
StrCat
(
"EndPoint="
,
end_point
,
", hwm="
,
std
::
to_string
(
hwm
));
}
};
struct
RecvTensorList
{
struct
RecvTensorList
{
zmq
::
message_t
message
;
zmq
::
message_t
message
;
...
@@ -43,13 +56,20 @@ struct RecvTensorList {
...
@@ -43,13 +56,20 @@ struct RecvTensorList {
class
ZMQConnection
:
public
tensorflow
::
ResourceBase
{
class
ZMQConnection
:
public
tensorflow
::
ResourceBase
{
public:
public:
ZMQConnection
(
std
::
string
endpoint
,
int
zmq_socket_type
,
int
hwm
)
:
explicit
ZMQConnection
(
const
ZMQSocketDef
&
def
)
:
ctx_
(
1
),
sock_
(
ctx_
,
zmq_socket_type
)
{
def_
{
def
},
ctx_
{
1
},
sock_
{
ctx_
,
def
.
socket_type
}
{
sock_
.
setsockopt
(
ZMQ_RCVHWM
,
&
hwm
,
sizeof
hwm
);
int
linger
=
0
;
sock_
.
bind
(
endpoint
.
c_str
());
sock_
.
setsockopt
(
ZMQ_LINGER
,
&
linger
,
sizeof
linger
);
sock_
.
setsockopt
(
ZMQ_RCVHWM
,
&
def
.
hwm
,
sizeof
def
.
hwm
);
if
(
def
.
bind
)
{
sock_
.
bind
(
def
.
end_point
.
c_str
());
}
else
{
sock_
.
connect
(
def
.
end_point
.
c_str
());
}
}
}
std
::
string
DebugString
()
override
{
return
""
;
}
std
::
string
DebugString
()
override
{
return
def_
.
DebugString
()
;
}
void
recv_tensor_list
(
RecvTensorList
*
tlist
)
{
void
recv_tensor_list
(
RecvTensorList
*
tlist
)
{
{
{
...
@@ -86,7 +106,10 @@ class ZMQConnection : public tensorflow::ResourceBase {
...
@@ -86,7 +106,10 @@ class ZMQConnection : public tensorflow::ResourceBase {
}
}
}
}
const
ZMQSocketDef
&
get_socket_def
()
const
{
return
def_
;
}
private:
private:
ZMQSocketDef
def_
;
tensorflow
::
mutex
mu_
;
tensorflow
::
mutex
mu_
;
zmq
::
context_t
ctx_
;
zmq
::
context_t
ctx_
;
zmq
::
socket_t
sock_
;
zmq
::
socket_t
sock_
;
...
...
tensorpack/user_ops/zmq_recv_op.cc
View file @
d1ba5969
...
@@ -27,15 +27,17 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
...
@@ -27,15 +27,17 @@ class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
private:
private:
Status
CreateResource
(
ZMQConnection
**
ret
)
override
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
Status
CreateResource
(
ZMQConnection
**
ret
)
override
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
const
NodeDef
&
ndef
=
def
();
const
NodeDef
&
ndef
=
def
();
string
end_point
;
ZMQSocketDef
sockdef
;
int
hwm
;
sockdef
.
socket_type
=
ZMQ_PULL
;
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"end_point"
,
&
end_point
));
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"bind"
,
&
sockdef
.
bind
));
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"hwm"
,
&
hwm
));
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"end_point"
,
&
sockdef
.
end_point
));
*
ret
=
new
ZMQConnection
(
end_point
,
ZMQ_PULL
,
hwm
);
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"hwm"
,
&
sockdef
.
hwm
));
*
ret
=
new
ZMQConnection
(
sockdef
);
return
Status
::
OK
();
return
Status
::
OK
();
}
}
// TODO verify
// Can verify, but probably not necessary because python is not going to eval this op twice with
// the same shared name
};
};
...
@@ -46,7 +48,6 @@ class ZMQRecvOp: public AsyncOpKernel {
...
@@ -46,7 +48,6 @@ class ZMQRecvOp: public AsyncOpKernel {
}
}
void
ComputeAsync
(
OpKernelContext
*
ctx
,
DoneCallback
done
)
override
{
void
ComputeAsync
(
OpKernelContext
*
ctx
,
DoneCallback
done
)
override
{
//GuardedTimer tm("Compute");
ZMQConnection
*
conn
=
nullptr
;
ZMQConnection
*
conn
=
nullptr
;
OP_REQUIRES_OK_ASYNC
(
OP_REQUIRES_OK_ASYNC
(
ctx
,
LookupResource
(
ctx
,
HandleFromInput
(
ctx
,
0
),
&
conn
),
done
);
ctx
,
LookupResource
(
ctx
,
HandleFromInput
(
ctx
,
0
),
&
conn
),
done
);
...
@@ -105,6 +106,7 @@ REGISTER_OP("ZMQConnection")
...
@@ -105,6 +106,7 @@ REGISTER_OP("ZMQConnection")
.
Output
(
"handle: resource"
)
.
Output
(
"handle: resource"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"hwm: int >= 1 = 10"
)
.
Attr
(
"hwm: int >= 1 = 10"
)
.
Attr
(
"bind: bool = true"
)
.
Attr
(
"container: string = ''"
)
.
Attr
(
"container: string = ''"
)
.
Attr
(
"shared_name: string = ''"
)
.
Attr
(
"shared_name: string = ''"
)
...
@@ -115,8 +117,7 @@ REGISTER_OP("ZMQConnection")
...
@@ -115,8 +117,7 @@ REGISTER_OP("ZMQConnection")
Opens a ZMQ PULL socket and returns a handle to it as a resource.
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container.
bind: If false, will connect to the endpoint rather than bind to it.
Otherwise, a default container is used.
container: required for a resource op kernel.
shared_name: If non-empty, this queue will be shared under the given name
shared_name: If non-empty, this connection will be shared under the given name across multiple sessions.
across multiple sessions.
)doc"
);
)doc"
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment