Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
65c8b239
Commit
65c8b239
authored
Dec 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ZMQ] use AsyncOpKernel; better tests; use mutex. (#362)
parent
a0d60a64
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
90 additions
and
34 deletions
+90
-34
tensorpack/user_ops/test-recv-op.py
tensorpack/user_ops/test-recv-op.py
+55
-20
tensorpack/user_ops/zmq_conn.h
tensorpack/user_ops/zmq_conn.h
+17
-4
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+18
-10
No files found.
tensorpack/user_ops/test-recv-op.py
View file @
65c8b239
...
...
@@ -3,10 +3,11 @@
# File: test-recv-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
sys
import
os
import
zmq
import
argparse
import
multiprocessing
as
mp
import
time
import
numpy
as
np
os
.
environ
[
'TF_CPP_MIN_LOG_LEVEL'
]
=
'2'
import
tensorflow
as
tf
# noqa
...
...
@@ -19,27 +20,46 @@ from tensorpack.utils.concurrency import ( # noqa
ENDPOINT
=
'ipc://test-pipe'
if
__name__
==
'__main__'
:
try
:
num
=
int
(
sys
.
argv
[
1
])
except
(
ValueError
,
IndexError
):
num
=
10
DATA
=
[]
def
send
(
iterable
,
delay
=
0
):
ctx
=
zmq
.
Context
()
sok
=
ctx
.
socket
(
zmq
.
PUSH
)
sok
.
bind
(
ENDPOINT
)
for
dp
in
iterable
:
if
delay
>
0
:
time
.
sleep
(
delay
)
print
(
"Sending data to socket.."
)
sok
.
send
(
dumps_zmq_op
(
dp
))
time
.
sleep
(
999
)
def
random_array
(
num
):
ret
=
[]
for
k
in
range
(
num
):
arr1
=
np
.
random
.
rand
(
k
+
10
,
k
+
10
)
.
astype
(
'float32'
)
arr2
=
(
np
.
random
.
rand
((
k
+
10
)
*
2
)
*
10
)
.
astype
(
'uint8'
)
DATA
.
append
([
arr1
,
arr2
])
ret
.
append
([
arr1
,
arr2
])
return
ret
def
hash_dp
(
dp
):
return
sum
([
k
.
sum
()
for
k
in
dp
])
def
send
():
ctx
=
zmq
.
Context
()
sok
=
ctx
.
socket
(
zmq
.
PUSH
)
sok
.
connect
(
ENDPOINT
)
for
dp
in
DATA
:
sok
.
send
(
dumps_zmq_op
(
dp
))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--task'
,
default
=
'basic'
,
choices
=
[
'basic'
,
'tworecv'
])
parser
.
add_argument
(
'-n'
,
'--num'
,
type
=
int
,
default
=
10
)
args
=
parser
.
parse_args
()
if
args
.
task
==
'basic'
:
DATA
=
random_array
(
args
.
num
)
p
=
mp
.
Process
(
target
=
send
,
args
=
(
DATA
,))
ensure_proc_terminate
(
p
)
start_proc_mask_signal
(
p
)
def
recv
():
sess
=
tf
.
Session
()
recv
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
])
print
(
recv
)
...
...
@@ -49,8 +69,23 @@ if __name__ == '__main__':
assert
(
arr
[
0
]
==
truth
[
0
])
.
all
()
assert
(
arr
[
1
]
==
truth
[
1
])
.
all
()
p
=
mp
.
Process
(
target
=
send
)
ensure_proc_terminate
(
p
)
start_proc_mask_signal
(
p
)
recv
()
p
.
join
()
p
.
join
()
if
args
.
task
==
'tworecv'
:
DATA
=
random_array
(
args
.
num
)
hashes
=
[
hash_dp
(
dp
)
for
dp
in
DATA
]
print
(
hashes
)
p
=
mp
.
Process
(
target
=
send
,
args
=
(
DATA
,
0.00
))
ensure_proc_terminate
(
p
)
start_proc_mask_signal
(
p
)
sess
=
tf
.
Session
()
recv1
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
],
hwm
=
1
)
recv2
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
],
hwm
=
1
)
print
(
recv1
,
recv2
)
for
i
in
range
(
args
.
num
//
2
):
res1
,
res2
=
sess
.
run
([
recv1
,
recv2
])
h1
,
h2
=
hash_dp
(
res1
),
hash_dp
(
res2
)
print
(
"Recv "
,
i
,
h1
,
h2
)
assert
h1
in
hashes
and
h2
in
hashes
tensorpack/user_ops/zmq_conn.h
View file @
65c8b239
...
...
@@ -7,6 +7,7 @@
#include <iostream>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"
namespace
{
...
...
@@ -17,6 +18,8 @@ inline int read_int32(char** p) {
}
}
namespace
tensorpack
{
struct
RecvTensorList
{
zmq
::
message_t
message
;
...
...
@@ -35,13 +38,19 @@ class ZMQConnection {
ZMQConnection
(
std
::
string
endpoint
,
int
zmq_socket_type
,
int
hwm
)
:
ctx_
(
1
),
sock_
(
ctx_
,
zmq_socket_type
)
{
sock_
.
setsockopt
(
ZMQ_RCVHWM
,
&
hwm
,
sizeof
hwm
);
sock_
.
bind
(
endpoint
.
c_str
());
sock_
.
connect
(
endpoint
.
c_str
());
}
void
recv_tensor_list
(
RecvTensorList
*
tlist
)
{
// TODO critical section
bool
succ
=
sock_
.
recv
(
&
tlist
->
message
);
CHECK
(
succ
);
// no EAGAIN, because we are blocking
{
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// zmq socket is not thread safe
tensorflow
::
mutex_lock
lk
(
mu_
);
bool
succ
=
sock_
.
recv
(
&
tlist
->
message
);
// TODO this may throw
// possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// succ=false only if EAGAIN
CHECK
(
succ
);
// no EAGAIN, because we are blocking
}
char
*
pos
=
reinterpret_cast
<
char
*>
(
tlist
->
message
.
data
());
...
...
@@ -67,6 +76,10 @@ class ZMQConnection {
}
private:
tensorflow
::
mutex
mu_
;
zmq
::
context_t
ctx_
;
zmq
::
socket_t
sock_
;
};
}
// namespace tensorpack
tensorpack/user_ops/zmq_recv_op.cc
View file @
65c8b239
...
...
@@ -16,18 +16,21 @@ REGISTER_OP("ZMQRecv")
.
Output
(
"output: types"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"types: list(type) >= 1"
)
.
Attr
(
"hwm: int >= 1 = 10
0
"
)
.
Attr
(
"hwm: int >= 1 = 10"
)
.
SetShapeFn
(
shape_inference
::
UnknownShape
)
.
SetIsStateful
()
.
Doc
(
R"doc(
Receive a list of Tensors
from a ZMQ socke
t.
Receive a list of Tensors
by connecting to a ZMQ socket and pull from i
t.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc"
);
class
ZMQRecvOp
:
public
OpKernel
{
namespace
tensorpack
{
class
ZMQRecvOp
:
public
AsyncOpKernel
{
public:
explicit
ZMQRecvOp
(
OpKernelConstruction
*
context
)
:
OpKernel
(
context
)
{
explicit
ZMQRecvOp
(
OpKernelConstruction
*
context
)
:
Async
OpKernel
(
context
)
{
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"types"
,
&
component_types_
));
CHECK_EQ
(
conn_
.
get
(),
nullptr
);
...
...
@@ -39,36 +42,37 @@ class ZMQRecvOp: public OpKernel {
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
,
hwm
));
}
void
Compute
(
OpKernelContext
*
ctx
)
override
{
void
Compute
Async
(
OpKernelContext
*
ctx
,
DoneCallback
done
)
override
{
//GuardedTimer tm("Compute");
int
start
,
stop
;
TF_CHECK_OK
(
this
->
OutputRange
(
"output"
,
&
start
,
&
stop
)
);
OP_REQUIRES_OK_ASYNC
(
ctx
,
this
->
OutputRange
(
"output"
,
&
start
,
&
stop
),
done
);
RecvTensorList
tlist
;
conn_
->
recv_tensor_list
(
&
tlist
);
auto
&
tensors
=
tlist
.
tensors
;
OpOutputList
outputs
;
OP_REQUIRES_OK
(
ctx
,
ctx
->
output_list
(
"output"
,
&
outputs
)
);
OP_REQUIRES_OK
_ASYNC
(
ctx
,
ctx
->
output_list
(
"output"
,
&
outputs
),
done
);
CHECK
(
tensors
.
size
()
==
num_components
());
for
(
int
i
=
start
;
i
<
stop
;
++
i
)
{
Tensor
*
output
=
nullptr
;
int
j
=
i
-
start
;
auto
recv_dtype
=
tensors
[
j
].
dtype
;
OP_REQUIRES
(
OP_REQUIRES
_ASYNC
(
ctx
,
component_types_
[
j
]
==
recv_dtype
,
errors
::
InvalidArgument
(
"Type mismatch between parsed tensor ("
,
DataTypeString
(
recv_dtype
),
") and dtype ("
,
DataTypeString
(
component_types_
[
j
]),
")"
));
DataTypeString
(
component_types_
[
j
]),
")"
)
,
done
);
TensorShape
&
shape
=
tensors
[
j
].
shape
;
OP_REQUIRES_OK
(
ctx
,
ctx
->
allocate_output
(
i
,
shape
,
&
output
)
);
OP_REQUIRES_OK
_ASYNC
(
ctx
,
ctx
->
allocate_output
(
i
,
shape
,
&
output
),
done
);
auto
ptr
=
output
->
bit_casted_shaped
<
char
,
1
>
({
shape
.
num_elements
()});
memcpy
(
ptr
.
data
(),
tensors
[
j
].
buf
,
tensors
[
j
].
size
);
outputs
.
set
(
j
,
*
output
);
}
done
();
}
private:
DataTypeVector
component_types_
;
...
...
@@ -77,4 +81,8 @@ class ZMQRecvOp: public OpKernel {
size_t
num_components
()
const
{
return
component_types_
.
size
();
}
};
REGISTER_KERNEL_BUILDER
(
Name
(
"ZMQRecv"
).
Device
(
DEVICE_CPU
),
ZMQRecvOp
);
}
// namespace tensorpack
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment