Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a0d60a64
Commit
a0d60a64
authored
Dec 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix build and tests of zmq op (#362)
parent
05494dd6
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
56 deletions
+73
-56
tensorpack/user_ops/Makefile
tensorpack/user_ops/Makefile
+11
-6
tensorpack/user_ops/common.py
tensorpack/user_ops/common.py
+5
-3
tensorpack/user_ops/test-recv-op.py
tensorpack/user_ops/test-recv-op.py
+39
-39
tensorpack/user_ops/zmq_conn.h
tensorpack/user_ops/zmq_conn.h
+1
-2
tensorpack/user_ops/zmq_recv.py
tensorpack/user_ops/zmq_recv.py
+10
-3
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+7
-3
No files found.
tensorpack/user_ops/Makefile
View file @
a0d60a64
# $File: Makefile
# $Date: Tue
Oct 31 11:44:27 2017 +
0800
# $Date: Tue
Dec 12 18:04:22 2017 -
0800
OBJ_DIR
=
obj
PYTHON
=
python
UNAME_S
:=
$(
shell
uname
-s
)
ifeq
($(UNAME_S),Linux)
...
...
@@ -21,15 +22,19 @@ INCLUDE_DIR += $(shell pkg-config --cflags $(LIBS))
LDFLAGS
+=
$(
shell
pkg-config
$(LIBS)
--libs
)
CXXFLAGS
+=
$(INCLUDE_DIR)
CXXFLAGS
+=
-Wall
-Wextra
CXXFLAGS
+=
-Wall
-Wextra
-Wno-unused-parameter
-Wno-sign-compare
CXXFLAGS
+=
$(DEFINES)
-std
=
c++11
$(OPTFLAGS)
-fPIC
# TODO https://github.com/tensorflow/tensorflow/issues/1569
# You may need to disable this flag if you compile tensorflow yourself with gcc>=5
CXXFLAGS
+=
-D_GLIBCXX_USE_CXX11_ABI
=
0
ifneq
($(MAKECMDGOALS), clean)
TF_CXXFLAGS
?=
$(
shell
$(PYTHON)
-c
'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags(
)))
'
)
TF_LDFLAGS
?=
$(
shell
$(PYTHON)
-c
'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags(
)))
'
)
endif
CXXFLAGS
+=
$(TF_CXXFLAGS)
LDFLAGS
+=
$(OPTFLAGS)
LDFLAGS
+=
-shared
-fPIC
LDFLAGS
+=
$(TF_LDFLAGS)
ifeq
($(UNAME_S),Darwin)
LDFLAGS
+=
-Wl
,-undefined
-Wl
,dynamic_lookup
endif
...
...
@@ -41,7 +46,7 @@ OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES
=
$(OBJS:.o=.d)
# TODO what about mac?
SO
=
$(ccSOURCES:.cc=.so)
SO
=
zmq_recv_op.so
.PHONY
:
all clean
...
...
tensorpack/user_ops/common.py
View file @
a0d60a64
...
...
@@ -9,10 +9,11 @@ import os
def
compile
():
# TODO check modtime?
include_dir
=
tf
.
sysconfig
.
get_include
(
)
cxxflags
=
' '
.
join
(
tf
.
sysconfig
.
get_compile_flags
())
ldflags
=
' '
.
join
(
tf
.
sysconfig
.
get_link_flags
()
)
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
compile_cmd
=
'INCLUDE_DIR="-isystem {}" make -C "{}"'
.
format
(
include_dir
,
file_dir
)
compile_cmd
=
'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'
.
format
(
cxxflags
,
ldflags
,
file_dir
)
ret
=
os
.
system
(
compile_cmd
)
return
ret
...
...
@@ -20,6 +21,7 @@ def compile():
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def
get_ext_suffix
():
"""Determine library extension for various versions of Python."""
return
'.so'
# TODO
ext_suffix
=
sysconfig
.
get_config_var
(
'EXT_SUFFIX'
)
if
ext_suffix
:
return
ext_suffix
...
...
tensorpack/user_ops/test-recv-op.py
View file @
a0d60a64
...
...
@@ -11,46 +11,46 @@ import numpy as np
os
.
environ
[
'TF_CPP_MIN_LOG_LEVEL'
]
=
'2'
import
tensorflow
as
tf
# noqa
from
tensorpack.user_ops.zmq_recv
import
(
# noqa
zmq_recv
,
dump_tensor_protos
,
to_tensor_proto
)
zmq_recv
,
dumps_zmq_op
)
from
tensorpack.utils.concurrency
import
(
# noqa
start_proc_mask_signal
,
ensure_proc_terminate
)
try
:
num
=
int
(
sys
.
argv
[
1
])
except
ValueError
:
num
=
2
ENDPOINT
=
'ipc://test-pipe'
DATA
=
[]
for
k
in
range
(
num
):
arr1
=
np
.
random
.
rand
(
k
+
10
,
k
+
10
)
.
astype
(
'float32'
)
arr2
=
(
np
.
random
.
rand
((
k
+
10
)
*
2
)
*
10
)
.
astype
(
'uint8'
)
DATA
.
append
([
arr1
,
arr2
])
def
send
():
ctx
=
zmq
.
Context
()
sok
=
ctx
.
socket
(
zmq
.
PUSH
)
sok
.
connect
(
ENDPOINT
)
for
arr1
,
arr2
in
DATA
:
t1
=
to_tensor_proto
(
arr1
)
t2
=
to_tensor_proto
(
arr2
)
t
=
dump_tensor_protos
([
t1
,
t2
])
sok
.
send
(
t
)
def
recv
():
sess
=
tf
.
InteractiveSession
()
recv
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
])
print
(
recv
)
for
truth
in
DATA
:
arr
=
sess
.
run
(
recv
)
assert
(
arr
[
0
]
==
truth
[
0
])
.
all
()
assert
(
arr
[
1
]
==
truth
[
1
])
.
all
()
p
=
mp
.
Process
(
target
=
send
)
p
.
start
()
recv
()
p
.
join
()
if
__name__
==
'__main__'
:
try
:
num
=
int
(
sys
.
argv
[
1
])
except
(
ValueError
,
IndexError
):
num
=
10
DATA
=
[]
for
k
in
range
(
num
):
arr1
=
np
.
random
.
rand
(
k
+
10
,
k
+
10
)
.
astype
(
'float32'
)
arr2
=
(
np
.
random
.
rand
((
k
+
10
)
*
2
)
*
10
)
.
astype
(
'uint8'
)
DATA
.
append
([
arr1
,
arr2
])
def
send
():
ctx
=
zmq
.
Context
()
sok
=
ctx
.
socket
(
zmq
.
PUSH
)
sok
.
connect
(
ENDPOINT
)
for
dp
in
DATA
:
sok
.
send
(
dumps_zmq_op
(
dp
))
def
recv
():
sess
=
tf
.
Session
()
recv
=
zmq_recv
(
ENDPOINT
,
[
tf
.
float32
,
tf
.
uint8
])
print
(
recv
)
for
truth
in
DATA
:
arr
=
sess
.
run
(
recv
)
assert
(
arr
[
0
]
==
truth
[
0
])
.
all
()
assert
(
arr
[
1
]
==
truth
[
1
])
.
all
()
p
=
mp
.
Process
(
target
=
send
)
ensure_proc_terminate
(
p
)
start_proc_mask_signal
(
p
)
recv
()
p
.
join
()
tensorpack/user_ops/zmq_conn.h
View file @
a0d60a64
...
...
@@ -32,9 +32,8 @@ struct RecvTensorList {
class
ZMQConnection
{
public:
ZMQConnection
(
std
::
string
endpoint
,
int
zmq_socket_type
)
:
ZMQConnection
(
std
::
string
endpoint
,
int
zmq_socket_type
,
int
hwm
)
:
ctx_
(
1
),
sock_
(
ctx_
,
zmq_socket_type
)
{
int
hwm
=
100
;
// TODO make it an option
sock_
.
setsockopt
(
ZMQ_RCVHWM
,
&
hwm
,
sizeof
hwm
);
sock_
.
bind
(
endpoint
.
c_str
());
}
...
...
tensorpack/user_ops/zmq_recv.py
View file @
a0d60a64
...
...
@@ -13,7 +13,7 @@ from tensorflow.core.framework import types_pb2 as DataType
from
.common
import
compile
,
get_ext_suffix
__all__
=
[
'zmq_recv'
,
'dumps_
for_tf
op'
,
__all__
=
[
'zmq_recv'
,
'dumps_
zmq_
op'
,
'dump_tensor_protos'
,
'to_tensor_proto'
]
...
...
@@ -26,7 +26,7 @@ def build():
else
:
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
recv_mod
=
tf
.
load_op_library
(
os
.
path
.
join
(
file_dir
,
'zmq_recv_op
.
'
+
get_ext_suffix
()))
os
.
path
.
join
(
file_dir
,
'zmq_recv_op'
+
get_ext_suffix
()))
zmq_recv
=
recv_mod
.
zmq_recv
...
...
@@ -51,6 +51,7 @@ def to_tensor_proto(arr):
Args:
arr: numpy.ndarray. only supports common numerical types
"""
assert
isinstance
(
arr
,
np
.
ndarray
),
type
(
arr
)
dtype
=
_DTYPE_DICT
[
arr
.
dtype
]
ret
=
TensorProto
()
...
...
@@ -100,9 +101,15 @@ def dump_tensor_protos(protos):
return
s
def
dumps_
for_tf
op
(
dp
):
def
dumps_
zmq_
op
(
dp
):
"""
Dump a datapoint (list of nparray) into a format that the ZMQRecv op in tensorpack would accept.
Args:
dp: list of nparray
Returns:
a binary string
"""
protos
=
[
to_tensor_proto
(
arr
)
for
arr
in
dp
]
return
dump_tensor_protos
(
protos
)
tensorpack/user_ops/zmq_recv_op.cc
View file @
a0d60a64
...
...
@@ -16,11 +16,12 @@ REGISTER_OP("ZMQRecv")
.
Output
(
"output: types"
)
.
Attr
(
"end_point: string"
)
.
Attr
(
"types: list(type) >= 1"
)
.
Attr
(
"hwm: int >= 1 = 100"
)
.
SetShapeFn
(
shape_inference
::
UnknownShape
)
.
SetIsStateful
()
.
Doc
(
R"doc(
Receive a
serialized
list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format.
Receive a list of Tensors from a ZMQ socket.
The serialization format is a tensorpack custom format
, defined in 'zmq_recv.py'
.
)doc"
);
...
...
@@ -32,7 +33,10 @@ class ZMQRecvOp: public OpKernel {
string
endpoint
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"end_point"
,
&
endpoint
));
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
));
int
hwm
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"hwm"
,
&
hwm
));
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
,
hwm
));
}
void
Compute
(
OpKernelContext
*
ctx
)
override
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment