Commit 0c67af01 authored Jan 06, 2018 by Yuxin Wu

    Move zmq ops to a separate project

parent b20f615d

Showing 10 changed files with 4 additions and 643 deletions:
    tensorpack/dataflow/remote.py              +2    -2
    tensorpack/input_source/input_source.py    +2    -2
    tensorpack/user_ops/.ycm_extra_conf.py     +0   -23
    tensorpack/user_ops/Makefile               +0   -74
    tensorpack/user_ops/__init__.py            +0    -2
    tensorpack/user_ops/common.py              +0   -39
    tensorpack/user_ops/test-pull-op.py        +0   -99
    tensorpack/user_ops/zmq_conn.h             +0  -119
    tensorpack/user_ops/zmq_ops.cc             +0  -123
    tensorpack/user_ops/zmq_ops.py             +0  -160
tensorpack/dataflow/remote.py

@@ -33,13 +33,13 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
         hwm (int): ZMQ high-water mark (buffer size)
         format (str): The serialization format.
             Default format would use :mod:`tensorpack.utils.serialize` (i.e. msgpack).
-            An alternate format is 'zmq_op'.
+            An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops.
     """
     assert format in [None, 'zmq_op']
     if format is None:
         dump_fn = dumps
     else:
-        from ..user_ops.zmq_recv import dumps_zmq_op
+        from zmq_ops import dumps_zmq_op
         dump_fn = dumps_zmq_op

     ctx = zmq.Context()
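For orientation, below is a minimal sender-side sketch of the 'zmq_op' path after this change. It is not part of the diff; it assumes the external zmq_ops package (https://github.com/tensorpack/zmq_ops) is installed and uses FakeData only as a stand-in dataflow.

# Hedged usage sketch: stream a dataflow to a ZMQ endpoint with the 'zmq_op'
# serialization, which now comes from the external zmq_ops project rather
# than tensorpack.user_ops.
from tensorpack.dataflow import FakeData
from tensorpack.dataflow.remote import send_dataflow_zmq

df = FakeData([[32, 32, 3], [1]], size=1000)   # placeholder dataflow
send_dataflow_zmq(df, 'ipc://@my-pipe', hwm=50, format='zmq_op')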
tensorpack/input_source/input_source.py

@@ -370,7 +370,7 @@ class DummyConstantInput(TensorInput):

 class ZMQInput(TensorInput):
     """
-    Recv tensors from a ZMQ endpoint.
+    Recv tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
     It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
     """
     def __init__(self, end_point, hwm):
@@ -395,7 +395,7 @@ class ZMQInput(TensorInput):
             "ZMQInput has to be used with InputDesc!"
         self._desc = inputs_desc

-        from ..user_ops import zmq_ops
+        import zmq_ops
         self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
             self._end_point,
             [x.type for x in inputs_desc],
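On the receiving side, the docstring above pairs ZMQInput with send_dataflow_zmq(format='zmq_op'). A rough sketch of that wiring, assuming ZMQInput is importable from tensorpack.input_source and using a placeholder endpoint:

# Hedged receiving-side sketch: ZMQInput pulls the tensors produced by
# send_dataflow_zmq(..., format='zmq_op'). After this commit it needs the
# external zmq_ops package at graph-setup time.
from tensorpack.input_source import ZMQInput

inp = ZMQInput('ipc://@my-pipe', hwm=50)   # binds; the sender connects to the same endpoint
# `inp` is then handed to a trainer together with the model's InputDesc list,
# which determines the dtypes pulled from the socket.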
tensorpack/user_ops/.ycm_extra_conf.py  deleted 100644 → 0

import tensorflow as tf

flags = [
    '-Wall',
    '-Wextra',
    '-Werror',
    '-Wno-long-long',
    '-Wno-variadic-macros',
    '-fexceptions',
    '-std=c++11',
    '-x', 'c++',
    '-isystem', tf.sysconfig.get_include()
]


def FlagsForFile(filename, **kwargs):
    return {
        'flags': flags,
        'do_cache': True
    }
tensorpack/user_ops/Makefile  deleted 100644 → 0

# $File: Makefile
# $Date: Thu Dec 21 14:12:30 2017 -0800

OBJ_DIR = obj
PYTHON = python

UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
CXX ?= g++
endif
ifeq ($(UNAME_S),Darwin)
CXX ?= clang++
endif

OPTFLAGS ?= -O3 -march=native
#OPTFLAGS ?= -g3 -fsanitize=address,undefined -O2 -lasan
#OPTFLAGS ?= -g3 -fsanitize=leak -O2 -lubsan

# libraries: TF preceeds others, so g++ looks for protobuf among TF headers
ifneq ($(MAKECMDGOALS), clean)
TF_CXXFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LDFLAGS ?= $(shell $(PYTHON) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')
endif
CXXFLAGS += $(TF_CXXFLAGS)
LDFLAGS += $(TF_LDFLAGS)

# extra packages from pkg-config
LIBS = libzmq
CXXFLAGS += $(shell pkg-config --cflags $(LIBS))
LDFLAGS += $(shell pkg-config $(LIBS) --libs)

CXXFLAGS += -Wall -Wextra -Wno-unused-parameter -Wno-sign-compare
CXXFLAGS += $(DEFINES) -std=c++11 $(OPTFLAGS) -fPIC
LDFLAGS += $(OPTFLAGS)
LDFLAGS += -shared -fPIC
ifeq ($(UNAME_S),Darwin)
LDFLAGS += -Wl,-undefined -Wl,dynamic_lookup
endif

SHELL = bash
# sources to include
ccSOURCES = $(shell find $(SRCDIRS) -name "*.cc" | sed 's/^\.\///g')
OBJS = $(addprefix $(OBJ_DIR)/,$(ccSOURCES:.cc=.o))
DEPFILES = $(OBJS:.o=.d)

EXT_SUFFIX ?= $(shell $(PYTHON) -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')
SO = zmq_ops$(EXT_SUFFIX)

.PHONY: all clean

all: $(SO)

ifneq ($(MAKECMDGOALS), clean)
sinclude $(DEPFILES)
endif

%$(EXT_SUFFIX): $(OBJ_DIR)/%.o
	@echo "Linking $@ ..."
	@$(CXX) $^ -o $@ $(LDFLAGS)
	@echo "done."

$(OBJ_DIR)/%.o: %.cc
	@echo "[cc] $< ..."
	@$(CXX) -c $< -o $@ $(CXXFLAGS)

$(OBJ_DIR)/%.d: %.cc Makefile
	@mkdir -pv $(dir $@)
	@echo "[dep] $< ..."
	@$(CXX) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cc=.o) $(OBJ_DIR)/$(<:.cc=.d)" "$<" > "$@" || rm "$@"

clean:
	@rm -rvf $(OBJ_DIR) $(SO)
tensorpack/user_ops/__init__.py  deleted 100644 → 0

#!/usr/bin/env python
# -*- coding: utf-8 -*-
tensorpack/user_ops/common.py  deleted 100644 → 0

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py

import sysconfig
import tensorflow as tf
import os

from ..utils import logger


# https://github.com/uber/horovod/blob/10835d25eccf4b198a23a0795edddf0896f6563d/horovod/tensorflow/mpi_ops.py#L30-L40
def get_ext_suffix():
    """Determine library extension for various versions of Python."""
    ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
    if ext_suffix:
        return ext_suffix

    ext_suffix = sysconfig.get_config_var('SO')
    if ext_suffix:
        return ext_suffix

    return '.so'


def compile():
    cxxflags = ' '.join(tf.sysconfig.get_compile_flags())
    ldflags = ' '.join(tf.sysconfig.get_link_flags())
    ext_suffix = get_ext_suffix()

    file_dir = os.path.dirname(os.path.abspath(__file__))
    compile_cmd = 'TF_CXXFLAGS="{}" TF_LDFLAGS="{}" EXT_SUFFIX="{}" make -C "{}"'.format(
        cxxflags, ldflags, ext_suffix, file_dir)
    logger.info("Compile user_ops by command " + compile_cmd + ' ...')
    ret = os.system(compile_cmd)
    return ret


if __name__ == '__main__':
    compile()
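For reference, the just-in-time build that compile() wraps could also be driven by hand while this module still existed. A small sketch, assuming a working libzmq, pkg-config, and TensorFlow headers; it is only meaningful before this commit, which removes tensorpack/user_ops.

# Hedged sketch: manually trigger the in-tree build that try_build() in
# zmq_ops.py runs on first import.
from tensorpack.user_ops.common import compile, get_ext_suffix

print("extension suffix:", get_ext_suffix())   # e.g. '.so' or a full CPython suffix
ret = compile()   # runs `make` in the user_ops directory via os.system
assert ret == 0, "compilation failed; check libzmq / pkg-config / TensorFlow headers"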
tensorpack/user_ops/test-pull-op.py  deleted 100644 → 0

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: test-pull-op.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import os
import zmq
import argparse
import multiprocessing as mp
import time
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf  # noqa
from tensorpack.user_ops.zmq_ops import (  # noqa
    ZMQPullSocket, dumps_zmq_op)
from tensorpack.utils.concurrency import (  # noqa
    start_proc_mask_signal, ensure_proc_terminate)

ENDPOINT = 'ipc://test-pipe'


def send(iterable, delay=0):
    ctx = zmq.Context()
    sok = ctx.socket(zmq.PUSH)
    sok.connect(ENDPOINT)

    for dp in iterable:
        if delay > 0:
            time.sleep(delay)
        print("Sending data to socket..")
        sok.send(dumps_zmq_op(dp))

    time.sleep(999)


def random_array(num):
    ret = []
    for k in range(num):
        arr1 = np.random.rand(k + 10, k + 10).astype('float32')
        # arr1 = 3.0
        arr2 = (np.random.rand((k + 10) * 2) * 10).astype('uint8')
        ret.append([arr1, arr2])
    return ret


def constant_array(num):
    arr = np.ones((30, 30)).astype('float32')
    arr2 = np.ones((3, 3)).astype('uint8')
    return [[arr, arr2]] * num


def hash_dp(dp):
    return sum([k.sum() for k in dp])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', default='basic',
                        choices=['basic', 'tworecv', 'send'])
    parser.add_argument('-n', '--num', type=int, default=10)
    args = parser.parse_args()

    if args.task == 'basic':
        DATA = random_array(args.num)
        p = mp.Process(target=send, args=(DATA,))
        ensure_proc_terminate(p)
        start_proc_mask_signal(p)

        sess = tf.Session()
        recv = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8]).pull()
        print(recv)
        for truth in DATA:
            arr = sess.run(recv)
            assert (arr[0] == truth[0]).all()
            assert (arr[1] == truth[1]).all()

    elif args.task == 'send':
        DATA = random_array(args.num)
        send(DATA)

    elif args.task == 'tworecv':
        DATA = random_array(args.num)
        hashes = [hash_dp(dp) for dp in DATA]
        print(hashes)
        p = mp.Process(target=send, args=(DATA, 0.00))
        ensure_proc_terminate(p)
        start_proc_mask_signal(p)

        sess = tf.Session()
        zmqsock = ZMQPullSocket(ENDPOINT, [tf.float32, tf.uint8], hwm=1)
        recv1 = zmqsock.pull()
        recv2 = zmqsock.pull()
        print(recv1, recv2)

        for i in range(args.num // 2):
            res1, res2 = sess.run([recv1, recv2])
            h1, h2 = hash_dp(res1), hash_dp(res2)
            print("Recv ", i, h1, h2)
            assert h1 in hashes and h2 in hashes
tensorpack/user_ops/zmq_conn.h  deleted 100644 → 0

//File: zmq_conn.h
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>

#pragma once

#include <string>
#include <iostream>
#include <thread>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/tensor_shape.h>
#include <tensorflow/core/lib/gtl/inlined_vector.h>
#include <tensorflow/core/lib/strings/strcat.h>
#include <tensorflow/core/platform/mutex.h>
#include "zmq.hpp"

namespace {

inline int read_int32(char** p) {
  auto pi = reinterpret_cast<const int*>(*p);
  *p += 4;
  return *pi;
}

inline tensorflow::int64 read_int64(char** p) {
  auto pi = reinterpret_cast<const tensorflow::int64*>(*p);
  *p += 8;
  return *pi;
}

}

namespace tensorpack {

struct ZMQSocketDef {
  std::string end_point;
  int socket_type,  // ZMQ_PULL
      hwm;
  bool bind;  // bind or connect

  std::string DebugString() const {
    return tensorflow::strings::StrCat("EndPoint=", end_point, ", hwm=", std::to_string(hwm));
  }
};

struct RecvTensorList {
  zmq::message_t message;

  struct TensorConstructor {
    tensorflow::DataType dtype;
    tensorflow::TensorShape shape;
    tensorflow::int64 buf_size;
    char* buf;
  };

  tensorflow::gtl::InlinedVector<TensorConstructor, 4> tensors;
};

class ZMQConnection : public tensorflow::ResourceBase {
 public:
  explicit ZMQConnection(const ZMQSocketDef& def):
      def_{def}, ctx_{1}, sock_{ctx_, def.socket_type} {
    int linger = 0;
    sock_.setsockopt(ZMQ_LINGER, &linger, sizeof linger);
    sock_.setsockopt(ZMQ_RCVHWM, &def.hwm, sizeof def.hwm);
    if (def.bind) {
      sock_.bind(def.end_point.c_str());
    } else {
      sock_.connect(def.end_point.c_str());
    }
  }

  std::string DebugString() override { return def_.DebugString(); }

  void recv_tensor_list(RecvTensorList* tlist) {
    {
      // https://www.tensorflow.org/extend/adding_an_op#multi-threaded_cpu_kernels
      // zmq socket is not thread safe
      tensorflow::mutex_lock lk(mu_);
      bool succ = sock_.recv(&tlist->message);  // block until some data appears
      // TODO this may throw, handle exception?
      // Possible error code: http://api.zeromq.org/3-3:zmq-msg-recv
      // succ=false only if EAGAIN
      CHECK(succ);  // no EAGAIN, because we are blocking
    }

    char* pos = reinterpret_cast<char*>(tlist->message.data());

    int num = read_int32(&pos);
    auto& tensors = tlist->tensors;
    tensors.resize(num);
    CHECK_LE(num, 15);  // probably a format error

    for (int i = 0; i < num; ++i) {
      int dt = read_int32(&pos);
      tensors[i].dtype = tensorflow::DataType(dt);
      int ndim = read_int32(&pos);
      CHECK_LE(ndim, 8);  // probably an error.
      for (int k = 0; k < ndim; ++k) {
        int shp = read_int32(&pos);
        tensors[i].shape.AddDim(shp);
      }
      tensorflow::int64 sz = read_int64(&pos);
      tensors[i].buf = pos;
      tensors[i].buf_size = sz;
      pos += sz;
    }
  }

  const ZMQSocketDef& get_socket_def() const { return def_; }

 private:
  ZMQSocketDef def_;
  tensorflow::mutex mu_;
  zmq::context_t ctx_;
  zmq::socket_t sock_;
};

}  // namespace tensorpack
tensorpack/user_ops/zmq_ops.cc  deleted 100644 → 0

//File: zmq_ops.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>

#include <string>
#include <memory>

#include <tensorflow/core/framework/op.h>
#include <tensorflow/core/framework/op_kernel.h>
#include <tensorflow/core/framework/resource_op_kernel.h>
#include <tensorflow/core/framework/resource_mgr.h>
#include <tensorflow/core/framework/common_shape_fns.h>

#include "zmq_conn.h"

using namespace std;
using namespace tensorflow;

namespace tensorpack {

// An op to create zmq connection as a resource.
// Use ResourceOpKernel to ensure singleton construction.
class ZMQConnectionHandleOp : public ResourceOpKernel<ZMQConnection> {
 public:
  explicit ZMQConnectionHandleOp(OpKernelConstruction* ctx)
      : ResourceOpKernel<ZMQConnection>(ctx) {}

 private:
  Status CreateResource(ZMQConnection** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const NodeDef& ndef = def();
    ZMQSocketDef sockdef;
    sockdef.socket_type = ZMQ_PULL;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "bind", &sockdef.bind));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "end_point", &sockdef.end_point));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "hwm", &sockdef.hwm));
    *ret = new ZMQConnection(sockdef);
    return Status::OK();
  }
  // Can verify, but probably not necessary because python is not going to eval this op twice with
  // the same shared name
};

class ZMQPullOp: public AsyncOpKernel {
 public:
  explicit ZMQPullOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("types", &component_types_));
  }

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    ZMQConnection* conn = nullptr;
    OP_REQUIRES_OK_ASYNC(
        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &conn), done);

    RecvTensorList tlist;
    conn->recv_tensor_list(&tlist);
    auto& tensors = tlist.tensors;
    CHECK(tensors.size() == num_components());

    for (int i = 0; i < tensors.size(); ++i) {
      Tensor* output = nullptr;
      auto recv_dtype = tensors[i].dtype;
      OP_REQUIRES_ASYNC(
          ctx, component_types_[i] == recv_dtype,
          errors::InvalidArgument("Type mismatch at index ", std::to_string(i),
                                  " between received tensor (", DataTypeString(recv_dtype),
                                  ") and dtype (", DataTypeString(component_types_[i]), ")"),
          done);

      TensorShape& shape = tensors[i].shape;
      OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(i, shape, &output), done);

      // reinterpret cast and then memcpy
      auto ptr = output->bit_casted_shaped<char, 1>({shape.num_elements()}).data();
      // {shape.num_elements() * DataTypeSize(recv_dtype)}).data();
      memcpy(ptr, tensors[i].buf, tensors[i].buf_size);
    }
    done();
  }

 private:
  DataTypeVector component_types_;

  size_t num_components() const { return component_types_.size(); }
};

REGISTER_KERNEL_BUILDER(Name("ZMQPull").Device(DEVICE_CPU), ZMQPullOp);
REGISTER_KERNEL_BUILDER(Name("ZMQConnection").Device(DEVICE_CPU), ZMQConnectionHandleOp);

}  // namespace tensorpack

REGISTER_OP("ZMQPull")
    .Input("handle: resource")
    .Output("output: types")
    .Attr("types: list(type) >= 1")
    .SetShapeFn(shape_inference::UnknownShape)
    .SetIsStateful()
    .Doc(R"doc(
Receive a list of Tensors from a ZMQ connection handle.
The serialization format is a tensorpack custom format, defined in 'zmq_recv.py'.
)doc");

REGISTER_OP("ZMQConnection")
    .Output("handle: resource")
    .Attr("end_point: string")
    .Attr("hwm: int >= 1 = 10")
    .Attr("bind: bool = true")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Opens a ZMQ PULL socket and returns a handle to it as a resource.
end_point: the ZMQ end point.
hwm: ZMQ high-water mark.
bind: If false, will connect to the endpoint rather than bind to it.
container: required for a resource op kernel.
shared_name: If non-empty, this connection will be shared under the given name across multiple sessions.
)doc");
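The two ops compose as a handle producer and consumer: ZMQConnection creates the PULL-socket resource, and ZMQPull reads one serialized datapoint from it. Below is a rough, hedged sketch of driving them directly through tf.load_op_library; the snake_case names are TensorFlow's generated wrappers for these ops (the Python module below relies on the same names), and the .so path is a placeholder.

# Hedged sketch of using the registered ops without the ZMQPullSocket wrapper
# defined in zmq_ops.py below.
import tensorflow as tf

mod = tf.load_op_library('./zmq_ops.so')   # placeholder path to the built library
handle = mod.zmq_connection('ipc://@my-pipe', 10, bind=True, shared_name='conn0')
tensors = mod.zmq_pull(handle, [tf.float32, tf.uint8])   # blocks until a datapoint arrives

with tf.Session() as sess:
    arrays = sess.run(tensors)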
tensorpack/user_ops/zmq_ops.py  deleted 100644 → 0

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: zmq_pull.py

import tensorflow as tf
import struct
import numpy as np
import os
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework import types_pb2 as DT
# have to import like this: https://github.com/tensorflow/tensorflow/commit/955f038afbeb81302cea43058078e68574000bce

from .common import compile, get_ext_suffix

__all__ = ['dumps_zmq_op', 'ZMQPullSocket']


_zmq_mod = None


def try_build():
    file_dir = os.path.dirname(os.path.abspath(__file__))
    basename = 'zmq_ops' + get_ext_suffix()
    so_file = os.path.join(file_dir, basename)
    if not os.path.isfile(so_file):
        ret = compile()
        if ret != 0:
            raise RuntimeError("tensorpack user_ops compilation failed!")

    global _zmq_mod
    _zmq_mod = tf.load_op_library(so_file)


try_build()


class ZMQPullSocket(object):
    def __init__(self, end_point, types, hwm=None, bind=True, name=None):
        self._types = types
        assert isinstance(bind, bool), bind

        if name is None:
            self._name = (tf.get_default_graph()
                          .unique_name(self.__class__.__name__))
        else:
            self._name = name

        self._zmq_handle = _zmq_mod.zmq_connection(
            end_point, hwm, bind=bind, shared_name=self._name)

    @property
    def name(self):
        return self._name

    def pull(self):
        return _zmq_mod.zmq_pull(
            self._zmq_handle, self._types)


# copied from tensorflow/python/framework/dtypes.py
_DTYPE_DICT = {
    np.float16: DT.DT_HALF,
    np.float32: DT.DT_FLOAT,
    np.float64: DT.DT_DOUBLE,

    np.uint8: DT.DT_UINT8,
    np.uint16: DT.DT_UINT16,
    np.uint32: DT.DT_UINT32,
    np.uint64: DT.DT_UINT64,

    np.int64: DT.DT_INT64,
    np.int32: DT.DT_INT32,
    np.int16: DT.DT_INT16,
    np.int8: DT.DT_INT8,

    np.complex64: DT.DT_COMPLEX64,
    np.complex128: DT.DT_COMPLEX128,

    np.bool: DT.DT_BOOL,
}
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}


def to_tensor_proto(arr):
    """
    Convert a numpy array to TensorProto

    Args:
        arr: numpy.ndarray. only supports common numerical types
    """
    if isinstance(arr, float):
        arr = np.asarray(arr).astype('float32')
    elif isinstance(arr, int):
        arr = np.asarray(arr).astype('int32')
    assert isinstance(arr, np.ndarray), type(arr)
    try:
        dtype = _DTYPE_DICT[arr.dtype]
    except KeyError:
        raise KeyError("Dtype {} is unsupported by current ZMQ Op!".format(arr.dtype))

    ret = TensorProto()
    shape = ret.tensor_shape
    for s in arr.shape:
        d = shape.dim.add()
        d.size = s

    ret.dtype = dtype

    buf = arr.tobytes()
    ret.tensor_content = buf
    return ret


def dump_tensor_protos(protos):
    """
    Serialize a list of :class:`TensorProto`, for communication between custom TensorFlow ops.

    Args:
        protos (list): list of :class:`TensorProto` instance

    Notes:
        The format is:

        [#tensors(int32)]
        [tensor1][tensor2]...

        Where each tensor is:

        [dtype(int32)][ndims(int32)][shape[0](int32)]...[shape[n](int32)]
        [len(buffer)(int64)][buffer]
    """
    s = struct.pack('=i', len(protos))
    for p in protos:
        tensor_content = p.tensor_content

        s += struct.pack('=i', int(p.dtype))
        dims = p.tensor_shape.dim
        s += struct.pack('=i', len(dims))
        for k in dims:
            s += struct.pack('=i', k.size)
        s += struct.pack('=q', len(tensor_content))
        s += tensor_content
    return s


def dumps_zmq_op(dp):
    """
    Dump a datapoint (list of nparray) into a format that the ZMQPull op in tensorpack would accept.

    Args:
        dp: list of nparray

    Returns:
        a binary string
    """
    assert isinstance(dp, (list, tuple))
    protos = [to_tensor_proto(arr) for arr in dp]
    return dump_tensor_protos(protos)
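The byte layout documented in dump_tensor_protos (and parsed in C++ by zmq_conn.h above) can also be decoded in pure Python. The following is an illustrative sketch, not part of tensorpack; loads_zmq_op is a hypothetical helper, and the dtype-code table covers only a few common types, using the TensorFlow DataType enum values.

# Illustrative decoder sketch: inverse of dumps_zmq_op() for the documented layout
#   [#tensors(int32)] then, per tensor,
#   [dtype(int32)][ndims(int32)][shape[0..n](int32)][len(buffer)(int64)][buffer]
import struct
import numpy as np

# Partial reverse of _DTYPE_DICT; the keys are TensorFlow DataType enum values
# (DT_FLOAT=1, DT_DOUBLE=2, DT_INT32=3, DT_UINT8=4, DT_INT64=9). Extend as needed.
_NP_OF_DT = {1: np.float32, 2: np.float64, 3: np.int32, 4: np.uint8, 9: np.int64}


def loads_zmq_op(buf):
    """Decode a buffer produced by dumps_zmq_op back into a list of numpy arrays."""
    pos = 0
    num, = struct.unpack_from('=i', buf, pos)
    pos += 4
    arrays = []
    for _ in range(num):
        dt, ndim = struct.unpack_from('=ii', buf, pos)
        pos += 8
        shape = struct.unpack_from('={}i'.format(ndim), buf, pos)
        pos += 4 * ndim
        size, = struct.unpack_from('=q', buf, pos)
        pos += 8
        arrays.append(np.frombuffer(buf[pos:pos + size],
                                    dtype=_NP_OF_DT[dt]).reshape(shape))
        pos += size
    return arrays


# Round-trip check against the serializer defined above:
dp = [np.random.rand(4, 4).astype('float32'), np.arange(6, dtype='uint8')]
assert all((a == b).all() for a, b in zip(dp, loads_zmq_op(dumps_zmq_op(dp))))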