Commit b5c5a944 authored by Yuxin Wu

ZMQInput can run.

parent 6f6914af
@@ -73,5 +73,9 @@ model-*
 checkpoint
 *.json
 *.prototxt
-snippet
 *.txt
+
+# my personal stuff
+snippet
+examples/private
+TODO.md
@@ -7,7 +7,7 @@ import time
 from collections import deque
 from .base import DataFlow
 from ..utils import logger, get_tqdm
-from ..utils.serialize import dumps, loads
+from ..utils.serialize import dumps, loads, dumps_for_tfop
 try:
     import zmq
 except ImportError:
@@ -17,7 +17,7 @@ else:
 __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
 
 
-def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
+def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
     """
     Run DataFlow and send data to a ZMQ socket addr.
     It will dump and send each datapoint to this addr with a PUSH socket.
@@ -26,7 +26,11 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
         df (DataFlow): Will infinitely loop over the DataFlow.
         addr: a ZMQ socket addr.
         hwm (int): high water mark
+        format (str): The serialization format.
+            'msgpack' is the default format and is what RemoteDataZMQ expects.
+            Any other value uses the format expected by the ZMQRecv TensorFlow op.
     """
+    dump_fn = dumps if format == 'msgpack' else dumps_for_tfop
     ctx = zmq.Context()
     socket = ctx.socket(zmq.PUSH)
     socket.set_hwm(hwm)
@@ -39,7 +43,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
         while True:
             for dp in df.get_data():
                 start = time.time()
-                socket.send(dumps(dp), copy=False)
+                socket.send(dump_fn(dp), copy=False)
                 q.append(time.time() - start)
                 pbar.update(1)
                 if pbar.n % print_interval == 0:
...
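For illustration, a minimal sender sketch using the new format argument. The endpoints, the FakeData shapes, and the non-default format string 'tfop' are assumptions for this example only; per the code above, any value other than 'msgpack' selects dumps_for_tfop.

# Hypothetical sender script (not part of this commit).
from tensorpack.dataflow import FakeData
from tensorpack.dataflow.remote import send_dataflow_zmq

# Any DataFlow works; FakeData just produces random arrays of the given shapes.
df = FakeData([[4, 4], [10]], size=100)

# Default path: msgpack serialization, consumed by RemoteDataZMQ on the pull side.
# send_dataflow_zmq(df, 'tcp://127.0.0.1:8877')

# TensorProto path: any non-'msgpack' value picks dumps_for_tfop, whose output is
# meant for the ZMQRecv TensorFlow op (see ZMQInput below).  Loops forever.
send_dataflow_zmq(df, 'ipc:///tmp/zmq-input-pipe', format='tfop')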
@@ -16,7 +16,8 @@ from ..callbacks.concurrency import StartProcOrThread
 __all__ = ['InputData', 'FeedfreeInput',
            'QueueInput', 'BatchQueueInput',
-           'TensorInput', 'DummyConstantInput']
+           'ZMQInput',
+           'DummyConstantInput', 'TensorInput']
 
 
 @six.add_metaclass(ABCMeta)
@@ -154,7 +155,7 @@ class QueueInput(FeedfreeInput):
         self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
 
     def setup_training(self, trainer):
-        self.setup(trainer.model)
+        super(QueueInput, self).setup_training(trainer)
         trainer.register_callback(StartProcOrThread(self.thread))
 
     def get_input_tensors(self):
@@ -218,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
         self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
 
     def setup_training(self, trainer):
-        self.setup(trainer.model)
+        super(BatchQueueInput, self).setup_training(trainer)
         trainer.register_callback(StartProcOrThread(self.thread))
 
     def get_input_tensors(self):
@@ -282,3 +283,26 @@ class TensorInput(FeedfreeInput):
 
     def get_input_tensors(self):
         return self.get_tensor_fn()
+
+
+class ZMQInput(FeedfreeInput):
+    def __init__(self, endpoint):
+        self._endpoint = endpoint
+
+    def size(self):
+        raise NotImplementedError()
+
+    def setup(self, model):
+        self.input_placehdrs = model.get_reused_placehdrs()
+        assert len(self.input_placehdrs) > 0, \
+            "ZMQInput has to be used with input placeholders!"
+
+    def get_input_tensors(self):
+        from tensorpack.user_ops import zmq_recv
+        ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
+        if isinstance(ret, tf.Tensor):
+            ret = [ret]
+        assert len(ret) == len(self.input_placehdrs)
+        for qv, v in zip(ret, self.input_placehdrs):
+            qv.set_shape(v.get_shape())
+        return ret
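For context, a hedged sketch of what the receiving (training-graph) side could look like. The module path, the endpoint, and the FakeModel stand-in are assumptions; the only interface relied on is the one ZMQInput.setup() uses above, namely get_reused_placehdrs(), which a real ModelDesc would normally provide.

# Hypothetical receiver sketch (not part of this commit).
import tensorflow as tf
from tensorpack.train.input_data import ZMQInput   # module path assumed


class FakeModel(object):
    """Minimal stand-in exposing the one hook ZMQInput.setup() calls."""
    def get_reused_placehdrs(self):
        # Placeholders describe the dtypes/shapes of the datapoints that the
        # sender pushes with send_dataflow_zmq(..., format='tfop').
        return [tf.placeholder(tf.float32, [4, 4], name='image'),
                tf.placeholder(tf.int32, [10], name='label')]


inp = ZMQInput('ipc:///tmp/zmq-input-pipe')
inp.setup(FakeModel())
tensors = inp.get_input_tensors()   # outputs of zmq_recv, with shapes copied
                                    # from the placeholders above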
@@ -55,5 +55,5 @@ $(OBJ_DIR)/%.d: %.cc Makefile
 	@$(CXX) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cc=.o) $(OBJ_DIR)/$(<:.cc=.d)" "$<" > "$@" || rm "$@"
 
 clean:
-	@rm -rvf $(OBJ_DIR)
+	@rm -rvf $(OBJ_DIR) $(SO)
@@ -16,7 +16,6 @@ print("Compiling user ops ...")
 ret = os.system(compile_cmd)
 if ret != 0:
     print("Failed to compile user ops!")
-else:
-    recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
-    zmq_recv = recv_mod.zmq_recv
+recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
+zmq_recv = recv_mod.zmq_recv
-//File: recv_op.cc
+//File: zmq_recv_op.cc
 //Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 
 #include <string>
...
@@ -15,7 +15,8 @@ from tensorflow.core.framework import types_pb2 as DataType
 msgpack_numpy.patch()
 
-__all__ = ['loads', 'dumps']
+__all__ = ['loads', 'dumps', 'dumps_for_tfop', 'dump_tensor_protos',
+           'to_tensor_proto']
 
 
 def dumps(obj):
@@ -46,7 +47,7 @@ _DTYPE_DICT = {
 _DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
 
 
-# TODO support string tensor
+# TODO support string tensor and scalar
 def to_tensor_proto(arr):
     """
     Convert a numpy array to TensorProto
@@ -86,3 +87,8 @@ def dump_tensor_protos(protos):
         s += struct.pack('=i', len(buf))  # won't send stuff over 2G
         s += buf
     return s
+
+
+def dumps_for_tfop(dp):
+    protos = [to_tensor_proto(arr) for arr in dp]
+    return dump_tensor_protos(protos)
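A small, hedged example of the new serialization path (numpy arrays to TensorProtos to one length-prefixed byte string), mirroring what send_dataflow_zmq does per datapoint when a non-msgpack format is chosen. The shapes and dtypes are arbitrary; float32 and int32 are assumed to be covered by _DTYPE_DICT.

import numpy as np
from tensorpack.utils.serialize import dumps_for_tfop

# A datapoint is a list of numpy arrays, as produced by a DataFlow.
dp = [np.random.rand(4, 4).astype('float32'),
      np.arange(10, dtype='int32')]

buf = dumps_for_tfop(dp)
# Each array ends up as a serialized TensorProto preceded by its byte length
# packed as '=i' (see dump_tensor_protos above); this is the byte string that
# send_dataflow_zmq pushes and the zmq_recv op parses back into tensors.
print(len(buf))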