Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2da6f9ed
Commit
2da6f9ed
authored
Dec 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ZMQ] correct so name; use int64; support scalar; (#362)
parent
65c8b239
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
73 additions
and
47 deletions
+73
-47
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+4
-2
tensorpack/user_ops/Makefile
tensorpack/user_ops/Makefile
+4
-4
tensorpack/user_ops/common.py
tensorpack/user_ops/common.py
+13
-11
tensorpack/user_ops/test-recv-op.py
tensorpack/user_ops/test-recv-op.py
+12
-5
tensorpack/user_ops/zmq_conn.h
tensorpack/user_ops/zmq_conn.h
+12
-5
tensorpack/user_ops/zmq_recv.py
tensorpack/user_ops/zmq_recv.py
+21
-14
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+7
-6
No files found.
tensorpack/input_source/input_source.py
View file @
2da6f9ed
...
@@ -384,13 +384,15 @@ class ZMQInput(TensorInput):
...
@@ -384,13 +384,15 @@ class ZMQInput(TensorInput):
"""
"""
Not well implemented yet. Don't use.
Not well implemented yet. Don't use.
"""
"""
def
__init__
(
self
,
endpoint
):
def
__init__
(
self
,
endpoint
,
hwm
):
self
.
_endpoint
=
endpoint
self
.
_endpoint
=
endpoint
from
tensorpack.user_ops
import
zmq_recv
from
tensorpack.user_ops
import
zmq_recv
def
fn
():
def
fn
():
ret
=
zmq_recv
(
self
.
_endpoint
,
[
x
.
dtype
for
x
in
self
.
inputs_desc
])
ret
=
zmq_recv
(
self
.
_endpoint
,
[
x
.
dtype
for
x
in
self
.
inputs_desc
],
hwm
=
hwm
)
if
isinstance
(
ret
,
tf
.
Tensor
):
if
isinstance
(
ret
,
tf
.
Tensor
):
ret
=
[
ret
]
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
inputs_desc
)
assert
len
(
ret
)
==
len
(
self
.
inputs_desc
)
...
...
tensorpack/user_ops/Makefile
View file @
2da6f9ed
# $File: Makefile
# $File: Makefile
# $Date: Tue Dec 12
18:04:22
2017 -0800
# $Date: Tue Dec 12
22:27:38
2017 -0800
OBJ_DIR
=
obj
OBJ_DIR
=
obj
PYTHON
=
python
PYTHON
=
python
...
@@ -45,8 +45,8 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
...
@@ -45,8 +45,8 @@ ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS
=
$(
addprefix
$(OBJ_DIR)
/,
$(ccSOURCES:.cc=.o)
)
OBJS
=
$(
addprefix
$(OBJ_DIR)
/,
$(ccSOURCES:.cc=.o)
)
DEPFILES
=
$(OBJS:.o=.d)
DEPFILES
=
$(OBJS:.o=.d)
# TODO what about mac?
EXT_SUFFIX
?=
$(
shell
$(PYTHON)
-c
'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"
))
'
)
SO
=
zmq_recv_op
.so
SO
=
zmq_recv_op
$(EXT_SUFFIX)
.PHONY
:
all clean
.PHONY
:
all clean
...
@@ -56,7 +56,7 @@ ifneq ($(MAKECMDGOALS), clean)
...
@@ -56,7 +56,7 @@ ifneq ($(MAKECMDGOALS), clean)
sinclude
$(DEPFILES)
sinclude
$(DEPFILES)
endif
endif
%
.so
:
$(OBJ_DIR)/%.o
%
$(EXT_SUFFIX)
:
$(OBJ_DIR)/%.o
@
echo
"Linking
$@
..."
@
echo
"Linking
$@
..."
@
$(CXX)
$^
-o
$@
$(LDFLAGS)
@
$(CXX)
$^
-o
$@
$(LDFLAGS)
@
echo
"done."
@
echo
"done."
...
...
tensorpack/user_ops/common.py
View file @
2da6f9ed
...
@@ -2,26 +2,16 @@
...
@@ -2,26 +2,16 @@
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: common.py
# File: common.py
from
__future__
import
print_function
import
sysconfig
import
sysconfig
import
tensorflow
as
tf
import
tensorflow
as
tf
import
os
import
os
from
..utils
import
logger
def
compile
():
cxxflags
=
' '
.
join
(
tf
.
sysconfig
.
get_compile_flags
())
ldflags
=
' '
.
join
(
tf
.
sysconfig
.
get_link_flags
())
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
compile_cmd
=
'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" make -C "{}"'
.
format
(
cxxflags
,
ldflags
,
file_dir
)
ret
=
os
.
system
(
compile_cmd
)
return
ret
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def
get_ext_suffix
():
def
get_ext_suffix
():
"""Determine library extension for various versions of Python."""
"""Determine library extension for various versions of Python."""
return
'.so'
# TODO
ext_suffix
=
sysconfig
.
get_config_var
(
'EXT_SUFFIX'
)
ext_suffix
=
sysconfig
.
get_config_var
(
'EXT_SUFFIX'
)
if
ext_suffix
:
if
ext_suffix
:
return
ext_suffix
return
ext_suffix
...
@@ -33,5 +23,17 @@ def get_ext_suffix():
...
@@ -33,5 +23,17 @@ def get_ext_suffix():
return
'.so'
return
'.so'
def
compile
():
cxxflags
=
' '
.
join
(
tf
.
sysconfig
.
get_compile_flags
())
ldflags
=
' '
.
join
(
tf
.
sysconfig
.
get_link_flags
())
ext_suffix
=
get_ext_suffix
()
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
compile_cmd
=
'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" EXT_SUFFIX="{}" make -C "{}"'
.
format
(
cxxflags
,
ldflags
,
ext_suffix
,
file_dir
)
logger
.
info
(
"Compile user_ops by command "
+
compile_cmd
+
' ...'
)
ret
=
os
.
system
(
compile_cmd
)
return
ret
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
compile
()
compile
()
tensorpack/user_ops/test-recv-op.py
View file @
2da6f9ed
...
@@ -38,11 +38,18 @@ def random_array(num):
...
@@ -38,11 +38,18 @@ def random_array(num):
ret
=
[]
ret
=
[]
for
k
in
range
(
num
):
for
k
in
range
(
num
):
arr1
=
np
.
random
.
rand
(
k
+
10
,
k
+
10
)
.
astype
(
'float32'
)
arr1
=
np
.
random
.
rand
(
k
+
10
,
k
+
10
)
.
astype
(
'float32'
)
# arr1 = 3.0
arr2
=
(
np
.
random
.
rand
((
k
+
10
)
*
2
)
*
10
)
.
astype
(
'uint8'
)
arr2
=
(
np
.
random
.
rand
((
k
+
10
)
*
2
)
*
10
)
.
astype
(
'uint8'
)
ret
.
append
([
arr1
,
arr2
])
ret
.
append
([
arr1
,
arr2
])
return
ret
return
ret
def
constant_array
(
num
):
arr
=
np
.
ones
((
30
,
30
))
.
astype
(
'float32'
)
arr2
=
np
.
ones
((
3
,
3
))
.
astype
(
'uint8'
)
return
[[
arr
,
arr2
]]
*
num
def
hash_dp
(
dp
):
def
hash_dp
(
dp
):
return
sum
([
k
.
sum
()
for
k
in
dp
])
return
sum
([
k
.
sum
()
for
k
in
dp
])
...
@@ -50,7 +57,7 @@ def hash_dp(dp):
...
@@ -50,7 +57,7 @@ def hash_dp(dp):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--task'
,
default
=
'basic'
,
parser
.
add_argument
(
'--task'
,
default
=
'basic'
,
choices
=
[
'basic'
,
'tworecv'
])
choices
=
[
'basic'
,
'tworecv'
,
'send'
])
parser
.
add_argument
(
'-n'
,
'--num'
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
'-n'
,
'--num'
,
type
=
int
,
default
=
10
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -68,10 +75,10 @@ if __name__ == '__main__':
...
@@ -68,10 +75,10 @@ if __name__ == '__main__':
arr
=
sess
.
run
(
recv
)
arr
=
sess
.
run
(
recv
)
assert
(
arr
[
0
]
==
truth
[
0
])
.
all
()
assert
(
arr
[
0
]
==
truth
[
0
])
.
all
()
assert
(
arr
[
1
]
==
truth
[
1
])
.
all
()
assert
(
arr
[
1
]
==
truth
[
1
])
.
all
()
elif
args
.
task
==
'send'
:
p
.
join
(
)
DATA
=
random_array
(
args
.
num
)
send
(
DATA
)
if
args
.
task
==
'tworecv'
:
el
if
args
.
task
==
'tworecv'
:
DATA
=
random_array
(
args
.
num
)
DATA
=
random_array
(
args
.
num
)
hashes
=
[
hash_dp
(
dp
)
for
dp
in
DATA
]
hashes
=
[
hash_dp
(
dp
)
for
dp
in
DATA
]
print
(
hashes
)
print
(
hashes
)
...
...
tensorpack/user_ops/zmq_conn.h
View file @
2da6f9ed
...
@@ -16,6 +16,12 @@ inline int read_int32(char** p) {
...
@@ -16,6 +16,12 @@ inline int read_int32(char** p) {
*
p
+=
4
;
*
p
+=
4
;
return
*
pi
;
return
*
pi
;
}
}
inline
tensorflow
::
int64
read_int64
(
char
**
p
)
{
auto
pi
=
reinterpret_cast
<
const
long
long
*>
(
*
p
);
*
p
+=
8
;
return
*
pi
;
}
}
}
namespace
tensorpack
{
namespace
tensorpack
{
...
@@ -26,7 +32,7 @@ struct RecvTensorList {
...
@@ -26,7 +32,7 @@ struct RecvTensorList {
struct
TensorConstructor
{
struct
TensorConstructor
{
tensorflow
::
DataType
dtype
;
tensorflow
::
DataType
dtype
;
tensorflow
::
TensorShape
shape
;
tensorflow
::
TensorShape
shape
;
int
size
;
// TODO bufsize
tensorflow
::
int64
buf_size
;
char
*
buf
;
char
*
buf
;
};
};
...
@@ -46,8 +52,9 @@ class ZMQConnection {
...
@@ -46,8 +52,9 @@ class ZMQConnection {
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
// zmq socket is not thread safe
// zmq socket is not thread safe
tensorflow
::
mutex_lock
lk
(
mu_
);
tensorflow
::
mutex_lock
lk
(
mu_
);
bool
succ
=
sock_
.
recv
(
&
tlist
->
message
);
// TODO this may throw
bool
succ
=
sock_
.
recv
(
&
tlist
->
message
);
// block until some data appears
// possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// TODO this may throw, handle exception?
// Possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
// succ=false only if EAGAIN
// succ=false only if EAGAIN
CHECK
(
succ
);
// no EAGAIN, because we are blocking
CHECK
(
succ
);
// no EAGAIN, because we are blocking
}
}
...
@@ -68,9 +75,9 @@ class ZMQConnection {
...
@@ -68,9 +75,9 @@ class ZMQConnection {
int
shp
=
read_int32
(
&
pos
);
int
shp
=
read_int32
(
&
pos
);
tensors
[
i
].
shape
.
AddDim
(
shp
);
tensors
[
i
].
shape
.
AddDim
(
shp
);
}
}
int
sz
=
read_int32
(
&
pos
);
tensorflow
::
int64
sz
=
read_int64
(
&
pos
);
tensors
[
i
].
buf
=
pos
;
tensors
[
i
].
buf
=
pos
;
tensors
[
i
].
size
=
sz
;
tensors
[
i
].
buf_
size
=
sz
;
pos
+=
sz
;
pos
+=
sz
;
}
}
}
}
...
...
tensorpack/user_ops/zmq_recv.py
View file @
2da6f9ed
...
@@ -17,17 +17,18 @@ __all__ = ['zmq_recv', 'dumps_zmq_op',
...
@@ -17,17 +17,18 @@ __all__ = ['zmq_recv', 'dumps_zmq_op',
'dump_tensor_protos'
,
'to_tensor_proto'
]
'dump_tensor_protos'
,
'to_tensor_proto'
]
# TODO '.so' for linux only
def
build
():
def
build
():
global
zmq_recv
global
zmq_recv
ret
=
compile
()
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
if
ret
!=
0
:
basename
=
'zmq_recv_op'
+
get_ext_suffix
()
zmq_recv
=
None
so_file
=
os
.
path
.
join
(
file_dir
,
basename
)
else
:
if
not
os
.
path
.
isfile
(
so_file
):
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
ret
=
compile
()
recv_mod
=
tf
.
load_op_library
(
if
ret
!=
0
:
os
.
path
.
join
(
file_dir
,
'zmq_recv_op'
+
get_ext_suffix
()))
raise
RuntimeError
(
"tensorpack user_ops compilation failed!"
)
zmq_recv
=
recv_mod
.
zmq_recv
recv_mod
=
tf
.
load_op_library
(
so_file
)
zmq_recv
=
recv_mod
.
zmq_recv
build
()
build
()
...
@@ -43,7 +44,6 @@ _DTYPE_DICT = {
...
@@ -43,7 +44,6 @@ _DTYPE_DICT = {
_DTYPE_DICT
=
{
np
.
dtype
(
k
):
v
for
k
,
v
in
_DTYPE_DICT
.
items
()}
_DTYPE_DICT
=
{
np
.
dtype
(
k
):
v
for
k
,
v
in
_DTYPE_DICT
.
items
()}
# TODO support string tensor and scalar
def
to_tensor_proto
(
arr
):
def
to_tensor_proto
(
arr
):
"""
"""
Convert a numpy array to TensorProto
Convert a numpy array to TensorProto
...
@@ -51,8 +51,15 @@ def to_tensor_proto(arr):
...
@@ -51,8 +51,15 @@ def to_tensor_proto(arr):
Args:
Args:
arr: numpy.ndarray. only supports common numerical types
arr: numpy.ndarray. only supports common numerical types
"""
"""
if
isinstance
(
arr
,
float
):
arr
=
np
.
asarray
(
arr
)
.
astype
(
'float32'
)
elif
isinstance
(
arr
,
int
):
arr
=
np
.
asarray
(
arr
)
.
astype
(
'int32'
)
assert
isinstance
(
arr
,
np
.
ndarray
),
type
(
arr
)
assert
isinstance
(
arr
,
np
.
ndarray
),
type
(
arr
)
dtype
=
_DTYPE_DICT
[
arr
.
dtype
]
try
:
dtype
=
_DTYPE_DICT
[
arr
.
dtype
]
except
KeyError
:
raise
KeyError
(
"Dtype {} is unsupported by current ZMQ Op!"
.
format
(
arr
.
dtype
))
ret
=
TensorProto
()
ret
=
TensorProto
()
shape
=
ret
.
tensor_shape
shape
=
ret
.
tensor_shape
...
@@ -83,9 +90,8 @@ def dump_tensor_protos(protos):
...
@@ -83,9 +90,8 @@ def dump_tensor_protos(protos):
Where each tensor is:
Where each tensor is:
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
[len(buffer)(int
32
)][buffer]
[len(buffer)(int
64
)][buffer]
"""
"""
# TODO use int64
s
=
struct
.
pack
(
'=i'
,
len
(
protos
))
s
=
struct
.
pack
(
'=i'
,
len
(
protos
))
for
p
in
protos
:
for
p
in
protos
:
...
@@ -96,7 +102,7 @@ def dump_tensor_protos(protos):
...
@@ -96,7 +102,7 @@ def dump_tensor_protos(protos):
s
+=
struct
.
pack
(
'=i'
,
len
(
dims
))
s
+=
struct
.
pack
(
'=i'
,
len
(
dims
))
for
k
in
dims
:
for
k
in
dims
:
s
+=
struct
.
pack
(
'=i'
,
k
.
size
)
s
+=
struct
.
pack
(
'=i'
,
k
.
size
)
s
+=
struct
.
pack
(
'=
i'
,
len
(
tensor_content
))
# won't send stuff over 2G
s
+=
struct
.
pack
(
'=
q'
,
len
(
tensor_content
))
s
+=
tensor_content
s
+=
tensor_content
return
s
return
s
...
@@ -111,5 +117,6 @@ def dumps_zmq_op(dp):
...
@@ -111,5 +117,6 @@ def dumps_zmq_op(dp):
Returns:
Returns:
a binary string
a binary string
"""
"""
assert
isinstance
(
dp
,
(
list
,
tuple
))
protos
=
[
to_tensor_proto
(
arr
)
for
arr
in
dp
]
protos
=
[
to_tensor_proto
(
arr
)
for
arr
in
dp
]
return
dump_tensor_protos
(
protos
)
return
dump_tensor_protos
(
protos
)
tensorpack/user_ops/zmq_recv_op.cc
View file @
2da6f9ed
...
@@ -27,7 +27,6 @@ The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'
...
@@ -27,7 +27,6 @@ The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'
namespace
tensorpack
{
namespace
tensorpack
{
class
ZMQRecvOp
:
public
AsyncOpKernel
{
class
ZMQRecvOp
:
public
AsyncOpKernel
{
public:
public:
explicit
ZMQRecvOp
(
OpKernelConstruction
*
context
)
:
AsyncOpKernel
(
context
)
{
explicit
ZMQRecvOp
(
OpKernelConstruction
*
context
)
:
AsyncOpKernel
(
context
)
{
...
@@ -39,6 +38,7 @@ class ZMQRecvOp: public AsyncOpKernel {
...
@@ -39,6 +38,7 @@ class ZMQRecvOp: public AsyncOpKernel {
int
hwm
;
int
hwm
;
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"hwm"
,
&
hwm
));
OP_REQUIRES_OK
(
context
,
context
->
GetAttr
(
"hwm"
,
&
hwm
));
// will get called only at the first sess.run call
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
,
hwm
));
conn_
.
reset
(
new
ZMQConnection
(
endpoint
,
ZMQ_PULL
,
hwm
));
}
}
...
@@ -61,15 +61,16 @@ class ZMQRecvOp: public AsyncOpKernel {
...
@@ -61,15 +61,16 @@ class ZMQRecvOp: public AsyncOpKernel {
auto
recv_dtype
=
tensors
[
j
].
dtype
;
auto
recv_dtype
=
tensors
[
j
].
dtype
;
OP_REQUIRES_ASYNC
(
OP_REQUIRES_ASYNC
(
ctx
,
component_types_
[
j
]
==
recv_dtype
,
ctx
,
component_types_
[
j
]
==
recv_dtype
,
errors
::
InvalidArgument
(
"Type mismatch between parsed tensor ("
,
errors
::
InvalidArgument
(
"Type mismatch at index "
,
std
::
to_string
(
j
),
DataTypeString
(
recv_dtype
),
") and dtype ("
,
" between received tensor ("
,
DataTypeString
(
recv_dtype
),
DataTypeString
(
component_types_
[
j
]),
")"
),
done
);
") and dtype ("
,
DataTypeString
(
component_types_
[
j
]),
")"
),
done
);
TensorShape
&
shape
=
tensors
[
j
].
shape
;
TensorShape
&
shape
=
tensors
[
j
].
shape
;
OP_REQUIRES_OK_ASYNC
(
ctx
,
ctx
->
allocate_output
(
i
,
shape
,
&
output
),
done
);
OP_REQUIRES_OK_ASYNC
(
ctx
,
ctx
->
allocate_output
(
i
,
shape
,
&
output
),
done
);
auto
ptr
=
output
->
bit_casted_shaped
<
char
,
1
>
({
shape
.
num_elements
()});
auto
ptr
=
output
->
bit_casted_shaped
<
char
,
1
>
({
shape
.
num_elements
()})
.
data
()
;
memcpy
(
ptr
.
data
(),
tensors
[
j
].
buf
,
tensors
[
j
].
size
);
memcpy
(
ptr
,
tensors
[
j
].
buf
,
tensors
[
j
].
buf_
size
);
outputs
.
set
(
j
,
*
output
);
outputs
.
set
(
j
,
*
output
);
}
}
done
();
done
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment