Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
99ddd038
Commit
99ddd038
authored
Dec 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ZMQ] use resource for ZMQ connection
parent
0594a9ad
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
53 deletions
+115
-53
tensorpack/user_ops/test-recv-op.py
tensorpack/user_ops/test-recv-op.py
+6
-5
tensorpack/user_ops/zmq_conn.h
tensorpack/user_ops/zmq_conn.h
+6
-2
tensorpack/user_ops/zmq_recv.py
tensorpack/user_ops/zmq_recv.py
+30
-6
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+73
-40
No files found.
tensorpack/user_ops/test-recv-op.py
View file @
99ddd038
...
...
@@ -12,7 +12,7 @@ import numpy as np
os
.
environ
[
'TF_CPP_MIN_LOG_LEVEL'
]
=
'2'
import
tensorflow
as
tf
# noqa
from
tensorpack.user_ops.zmq_recv
import
(
# noqa
zmq_r
ecv
,
dumps_zmq_op
)
ZMQR
ecv
,
dumps_zmq_op
)
from
tensorpack.utils.concurrency
import
(
# noqa
start_proc_mask_signal
,
ensure_proc_terminate
)
...
...
@@ -24,7 +24,7 @@ ENDPOINT = 'ipc://test-pipe'
def
send
(
iterable
,
delay
=
0
):
ctx
=
zmq
.
Context
()
sok
=
ctx
.
socket
(
zmq
.
PUSH
)
sok
.
bind
(
ENDPOINT
)
sok
.
connect
(
ENDPOINT
)
for
dp
in
iterable
:
if
delay
>
0
:
...
...
@@ -68,7 +68,7 @@ if __name__ == '__main__':
start_proc_mask_signal
(
p
)
sess
=
tf
.
Session
()
recv
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
]
)
recv
=
ZMQRecv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
])
.
recv
(
)
print
(
recv
)
for
truth
in
DATA
:
...
...
@@ -87,8 +87,9 @@ if __name__ == '__main__':
start_proc_mask_signal
(
p
)
sess
=
tf
.
Session
()
recv1
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
],
hwm
=
1
)
recv2
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
],
hwm
=
1
)
zmqsock
=
ZMQRecv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
],
hwm
=
1
)
recv1
=
zmqsock
.
recv
()
recv2
=
zmqsock
.
recv
()
print
(
recv1
,
recv2
)
for
i
in
range
(
args
.
num
//
2
):
...
...
tensorpack/user_ops/zmq_conn.h
View file @
99ddd038
...
...
@@ -5,8 +5,10 @@
#include <string>
#include <iostream>
#include <thread>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
...
...
@@ -39,14 +41,16 @@ struct RecvTensorList {
tensorflow
::
gtl
::
InlinedVector
<
TensorConstructor
,
4
>
tensors
;
};
class
ZMQConnection
{
class
ZMQConnection
:
public
tensorflow
::
ResourceBase
{
public:
ZMQConnection
(
std
::
string
endpoint
,
int
zmq_socket_type
,
int
hwm
)
:
ctx_
(
1
),
sock_
(
ctx_
,
zmq_socket_type
)
{
sock_
.
setsockopt
(
ZMQ_RCVHWM
,
&
hwm
,
sizeof
hwm
);
sock_
.
connect
(
endpoint
.
c_str
());
sock_
.
bind
(
endpoint
.
c_str
());
}
std
::
string
DebugString
()
override
{
return
""
;
}
void
recv_tensor_list
(
RecvTensorList
*
tlist
)
{
{
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
...
...
tensorpack/user_ops/zmq_recv.py
View file @
99ddd038
...
...
@@ -13,12 +13,14 @@ from tensorflow.core.framework import types_pb2 as DataType
from
.common
import
compile
,
get_ext_suffix
__all__
=
[
'
zmq_recv'
,
'dumps_zmq_op
'
,
__all__
=
[
'
dumps_zmq_op'
,
'ZMQRecv
'
,
'dump_tensor_protos'
,
'to_tensor_proto'
]
def
build
():
global
zmq_recv
_zmq_recv_mod
=
None
def
try_build
():
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
basename
=
'zmq_recv_op'
+
get_ext_suffix
()
so_file
=
os
.
path
.
join
(
file_dir
,
basename
)
...
...
@@ -27,11 +29,33 @@ def build():
if
ret
!=
0
:
raise
RuntimeError
(
"tensorpack user_ops compilation failed!"
)
recv_mod
=
tf
.
load_op_library
(
so_file
)
zmq_recv
=
recv_mod
.
zmq_recv
global
_zmq_recv_mod
_zmq_recv_mod
=
tf
.
load_op_library
(
so_file
)
try_build
()
class
ZMQRecv
(
object
):
def
__init__
(
self
,
end_point
,
types
,
hwm
=
None
,
name
=
None
):
self
.
_types
=
types
if
name
is
None
:
self
.
_name
=
(
tf
.
get_default_graph
()
.
unique_name
(
self
.
__class__
.
__name__
))
else
:
self
.
_name
=
name
self
.
_zmq_handle
=
_zmq_recv_mod
.
zmq_connection
(
end_point
,
hwm
,
shared_name
=
self
.
_name
)
@
property
def
name
(
self
):
return
self
.
_name
build
()
def
recv
(
self
):
return
_zmq_recv_mod
.
zmq_recv
(
self
.
_zmq_handle
,
self
.
_types
)
_DTYPE_DICT
=
{
...
...
tensorpack/user_ops/zmq_recv_op.cc
View file @
99ddd038
...
...
@@ -3,87 +3,120 @@
#include <string>
#include <memory>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/resource_op_kernel.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/common_shape_fns.h>
#include "zmq_conn.h"
using
namespace
std
;
using
namespace
tensorflow
;
REGISTER_OP
(
"ZMQRecv"
)
.
Output
(
"output: types"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"types: list(type) >= 1"
)
.
Attr
(
"hwm: int >= 1 = 10"
)
.
SetShapeFn
(
shape_inference
::
UnknownShape
)
.
SetIsStateful
()
.
Doc
(
R"doc(
Receive a list of Tensors by connecting to a ZMQ socket and pull from it.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc"
);
namespace
tensorpack
{
// An op to create zmq connection as a resource.
// Use ResourceOpKernel to ensure singleton construction.
class
ZMQConnectionHandleOp
:
public
ResourceOpKernel
<
ZMQConnection
>
{
public:
explicit
ZMQConnectionHandleOp
(
OpKernelConstruction
*
ctx
)
:
ResourceOpKernel
<
ZMQConnection
>
(
ctx
)
{}
private:
Status
CreateResource
(
ZMQConnection
**
ret
)
override
EXCLUSIVE_LOCKS_REQUIRED
(
mu_
)
{
const
NodeDef
&
ndef
=
def
();
string
end_point
;
int
hwm
;
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"end_point"
,
&
end_point
));
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
ndef
,
"hwm"
,
&
hwm
));
*
ret
=
new
ZMQConnection
(
end_point
,
ZMQ_PULL
,
hwm
);
return
Status
::
OK
();
}
// TODO verify
};
namespace
tensorpack
{
class
ZMQRecvOp
:
public
AsyncOpKernel
{
public:
explicit
ZMQRecvOp
(
OpKernelConstruction
*
context
)
:
AsyncOpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"types"
,
&
component_types_
));
CHECK_EQ
(
conn_
.
get
(),
nullptr
);
string
endpoint
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"end_point"
,
&
endpoint
));
int
hwm
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"hwm"
,
&
hwm
));
// will get called only at the first sess.run call
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
,
hwm
));
}
void
ComputeAsync
(
OpKernelContext
*
ctx
,
DoneCallback
done
)
override
{
//GuardedTimer tm("Compute");
int
start
,
stop
;
OP_REQUIRES_OK_ASYNC
(
ctx
,
this
->
OutputRange
(
"output"
,
&
start
,
&
stop
),
done
);
ZMQConnection
*
conn
=
nullptr
;
OP_REQUIRES_OK_ASYNC
(
ctx
,
LookupResource
(
ctx
,
HandleFromInput
(
ctx
,
0
),
&
conn
),
done
);
RecvTensorList
tlist
;
conn
_
->
recv_tensor_list
(
&
tlist
);
conn
->
recv_tensor_list
(
&
tlist
);
auto
&
tensors
=
tlist
.
tensors
;
OpOutputList
outputs
;
OP_REQUIRES_OK_ASYNC
(
ctx
,
ctx
->
output_list
(
"output"
,
&
outputs
),
done
);
CHECK
(
tensors
.
size
()
==
num_components
());
for
(
int
i
=
start
;
i
<
stop
;
++
i
)
{
for
(
int
i
=
0
;
i
<
tensors
.
size
()
;
++
i
)
{
Tensor
*
output
=
nullptr
;
int
j
=
i
-
start
;
auto
recv_dtype
=
tensors
[
j
].
dtype
;
auto
recv_dtype
=
tensors
[
i
].
dtype
;
OP_REQUIRES_ASYNC
(
ctx
,
component_types_
[
j
]
==
recv_dtype
,
errors
::
InvalidArgument
(
"Type mismatch at index "
,
std
::
to_string
(
j
),
ctx
,
component_types_
[
i
]
==
recv_dtype
,
errors
::
InvalidArgument
(
"Type mismatch at index "
,
std
::
to_string
(
i
),
" between received tensor ("
,
DataTypeString
(
recv_dtype
),
") and dtype ("
,
DataTypeString
(
component_types_
[
j
]),
")"
),
") and dtype ("
,
DataTypeString
(
component_types_
[
i
]),
")"
),
done
);
TensorShape
&
shape
=
tensors
[
j
].
shape
;
TensorShape
&
shape
=
tensors
[
i
].
shape
;
OP_REQUIRES_OK_ASYNC
(
ctx
,
ctx
->
allocate_output
(
i
,
shape
,
&
output
),
done
);
// reinterpret cast and then memcpy
auto
ptr
=
output
->
bit_casted_shaped
<
char
,
1
>
({
shape
.
num_elements
()}).
data
();
memcpy
(
ptr
,
tensors
[
j
].
buf
,
tensors
[
j
].
buf_size
);
outputs
.
set
(
j
,
*
output
);
memcpy
(
ptr
,
tensors
[
i
].
buf
,
tensors
[
i
].
buf_size
);
ctx
->
set_output
(
i
,
*
output
);
}
done
();
}
private:
DataTypeVector
component_types_
;
unique_ptr
<
ZMQConnection
>
conn_
;
size_t
num_components
()
const
{
return
component_types_
.
size
();
}
};
REGISTER_KERNEL_BUILDER
(
Name
(
"ZMQRecv"
).
Device
(
DEVICE_CPU
),
ZMQRecvOp
);
REGISTER_KERNEL_BUILDER
(
Name
(
"ZMQConnection"
).
Device
(
DEVICE_CPU
),
ZMQConnectionHandleOp
);
}
// namespace tensorpack
REGISTER_OP
(
"ZMQRecv"
)
.
Input
(
"handle: resource"
)
.
Output
(
"output: types"
)
.
Attr
(
"types: list(type) >= 1"
)
.
SetShapeFn
(
shape_inference
::
UnknownShape
)
.
SetIsStateful
()
.
Doc
(
R"doc(
Receive a list of Tensors from a ZMQ connection handle.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc"
);
REGISTER_OP
(
"ZMQConnection"
)
.
Output
(
"handle: resource"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"hwm: int >= 1 = 10"
)
.
Attr
(
"container: string = ''"
)
.
Attr
(
"shared_name: string = ''"
)
.
SetIsStateful
()
.
SetShapeFn
(
shape_inference
::
ScalarShape
)
.
Doc
(
R"doc(
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
container: If non-empty, this queue is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this queue will be shared under the given name
across multiple sessions.
)doc"
);
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment