Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b5c5a944
Commit
b5c5a944
authored
Apr 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ZMQInput can run.
parent
6f6914af
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
52 additions
and
15 deletions
+52
-15
.gitignore
.gitignore
+5
-1
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+7
-3
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+27
-3
tensorpack/user_ops/Makefile
tensorpack/user_ops/Makefile
+1
-1
tensorpack/user_ops/__init__.py
tensorpack/user_ops/__init__.py
+3
-4
tensorpack/user_ops/zmq_recv_op.cc
tensorpack/user_ops/zmq_recv_op.cc
+1
-1
tensorpack/utils/serialize.py
tensorpack/utils/serialize.py
+8
-2
No files found.
.gitignore
View file @
b5c5a944
...
...
@@ -73,5 +73,9 @@ model-*
checkpoint
*.json
*.prototxt
snippet
*.txt
# my personal stuff
snippet
examples/private
TODO.md
tensorpack/dataflow/remote.py
View file @
b5c5a944
...
...
@@ -7,7 +7,7 @@ import time
from
collections
import
deque
from
.base
import
DataFlow
from
..utils
import
logger
,
get_tqdm
from
..utils.serialize
import
dumps
,
loads
from
..utils.serialize
import
dumps
,
loads
,
dumps_for_tfop
try
:
import
zmq
except
ImportError
:
...
...
@@ -17,7 +17,7 @@ else:
__all__
=
[
'send_dataflow_zmq'
,
'RemoteDataZMQ'
]
def
send_dataflow_zmq
(
df
,
addr
,
hwm
=
50
,
print_interval
=
100
):
def
send_dataflow_zmq
(
df
,
addr
,
hwm
=
50
,
print_interval
=
100
,
format
=
'msgpack'
):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
...
...
@@ -26,7 +26,11 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr.
hwm (int): high water mark
format (str): The serialization format.
'msgpack' is the default format corresponding to RemoteDataZMQ.
Otherwise will use the format corresponding to the ZMQRecv TensorFlow Op.
"""
dump_fn
=
dumps
if
format
==
'msgpack'
else
dumps_for_tfop
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
.
set_hwm
(
hwm
)
...
...
@@ -39,7 +43,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
while
True
:
for
dp
in
df
.
get_data
():
start
=
time
.
time
()
socket
.
send
(
dump
s
(
dp
),
copy
=
False
)
socket
.
send
(
dump
_fn
(
dp
),
copy
=
False
)
q
.
append
(
time
.
time
()
-
start
)
pbar
.
update
(
1
)
if
pbar
.
n
%
print_interval
==
0
:
...
...
tensorpack/train/input_data.py
View file @
b5c5a944
...
...
@@ -16,7 +16,8 @@ from ..callbacks.concurrency import StartProcOrThread
__all__
=
[
'InputData'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
'TensorInput'
,
'DummyConstantInput'
]
'ZMQInput'
,
'DummyConstantInput'
,
'TensorInput'
]
@
six
.
add_metaclass
(
ABCMeta
)
...
...
@@ -154,7 +155,7 @@ class QueueInput(FeedfreeInput):
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
def
setup_training
(
self
,
trainer
):
s
elf
.
setup
(
trainer
.
model
)
s
uper
(
QueueInput
,
self
)
.
setup_training
(
trainer
)
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
...
...
@@ -218,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
def
setup_training
(
self
,
trainer
):
s
elf
.
setup
(
trainer
.
model
)
s
uper
(
BatchQueueInput
,
self
)
.
setup_training
(
trainer
)
trainer
.
register_callback
(
StartProcOrThread
(
self
.
thread
))
def
get_input_tensors
(
self
):
...
...
@@ -282,3 +283,26 @@ class TensorInput(FeedfreeInput):
def
get_input_tensors
(
self
):
return
self
.
get_tensor_fn
()
class
ZMQInput
(
FeedfreeInput
):
def
__init__
(
self
,
endpoint
):
self
.
_endpoint
=
endpoint
def
size
(
self
):
raise
NotImplementedError
()
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"ZMQInput has to be used with input placeholders!"
def
get_input_tensors
(
self
):
from
tensorpack.user_ops
import
zmq_recv
ret
=
zmq_recv
(
self
.
_endpoint
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
])
if
isinstance
(
self
.
_recv
,
tf
.
Tensor
):
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_placehdrs
)
for
qv
,
v
in
zip
(
ret
,
self
.
input_placehdrs
):
qv
.
set_shape
(
v
.
get_shape
())
return
ret
tensorpack/user_ops/Makefile
View file @
b5c5a944
...
...
@@ -55,5 +55,5 @@ $(OBJ_DIR)/%.d: %.cc Makefile
@
$(CXX)
$(CXXFLAGS)
-MM
-MT
"
$(OBJ_DIR)
/
$
(<:.cc=.o)
$(OBJ_DIR)
/
$
(<:.cc=.d)"
"
$<
"
>
"
$@
"
||
rm
"
$@
"
clean
:
@
rm
-rvf
$(OBJ_DIR)
@
rm
-rvf
$(OBJ_DIR)
$(SO)
tensorpack/user_ops/__init__.py
View file @
b5c5a944
...
...
@@ -16,7 +16,6 @@ print("Compiling user ops ...")
ret
=
os
.
system
(
compile_cmd
)
if
ret
!=
0
:
print
(
"Failed to compile user ops!"
)
recv_mod
=
tf
.
load_op_library
(
os
.
path
.
join
(
file_dir
,
'zmq_recv_op.so'
))
zmq_recv
=
recv_mod
.
zmq_recv
else
:
recv_mod
=
tf
.
load_op_library
(
os
.
path
.
join
(
file_dir
,
'zmq_recv_op.so'
))
zmq_recv
=
recv_mod
.
zmq_recv
tensorpack/user_ops/zmq_recv_op.cc
View file @
b5c5a944
//File: recv_op.cc
//File:
zmq_
recv_op.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string>
...
...
tensorpack/utils/serialize.py
View file @
b5c5a944
...
...
@@ -15,7 +15,8 @@ from tensorflow.core.framework import types_pb2 as DataType
msgpack_numpy
.
patch
()
__all__
=
[
'loads'
,
'dumps'
]
__all__
=
[
'loads'
,
'dumps'
,
'dumps_for_tfop'
,
'dump_tensor_protos'
,
'to_tensor_proto'
]
def
dumps
(
obj
):
...
...
@@ -46,7 +47,7 @@ _DTYPE_DICT = {
_DTYPE_DICT
=
{
np
.
dtype
(
k
):
v
for
k
,
v
in
_DTYPE_DICT
.
items
()}
# TODO support string tensor
# TODO support string tensor
and scalar
def
to_tensor_proto
(
arr
):
"""
Convert a numpy array to TensorProto
...
...
@@ -86,3 +87,8 @@ def dump_tensor_protos(protos):
s
+=
struct
.
pack
(
'=i'
,
len
(
buf
))
# won't send stuff over 2G
s
+=
buf
return
s
def
dumps_for_tfop
(
dp
):
protos
=
[
to_tensor_proto
(
arr
)
for
arr
in
dp
]
return
dump_tensor_protos
(
protos
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment