Commit b5c5a944 authored by Yuxin Wu's avatar Yuxin Wu

ZMQInput can run.

parent 6f6914af
......@@ -73,5 +73,9 @@ model-*
checkpoint
*.json
*.prototxt
snippet
*.txt
# my personal stuff
snippet
examples/private
TODO.md
......@@ -7,7 +7,7 @@ import time
from collections import deque
from .base import DataFlow
from ..utils import logger, get_tqdm
from ..utils.serialize import dumps, loads
from ..utils.serialize import dumps, loads, dumps_for_tfop
try:
import zmq
except ImportError:
......@@ -17,7 +17,7 @@ else:
__all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format='msgpack'):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
......@@ -26,7 +26,11 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr.
hwm (int): high water mark
format (str): The serialization format.
'msgpack' is the default format corresponding to RemoteDataZMQ.
Otherwise will use the format corresponding to the ZMQRecv TensorFlow Op.
"""
dump_fn = dumps if format == 'msgpack' else dumps_for_tfop
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm)
......@@ -39,7 +43,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100):
while True:
for dp in df.get_data():
start = time.time()
socket.send(dumps(dp), copy=False)
socket.send(dump_fn(dp), copy=False)
q.append(time.time() - start)
pbar.update(1)
if pbar.n % print_interval == 0:
......
......@@ -16,7 +16,8 @@ from ..callbacks.concurrency import StartProcOrThread
__all__ = ['InputData', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'TensorInput', 'DummyConstantInput']
'ZMQInput',
'DummyConstantInput', 'TensorInput']
@six.add_metaclass(ABCMeta)
......@@ -154,7 +155,7 @@ class QueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs)
def setup_training(self, trainer):
self.setup(trainer.model)
super(QueueInput, self).setup_training(trainer)
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
......@@ -218,7 +219,7 @@ class BatchQueueInput(FeedfreeInput):
self.thread = EnqueueThread(self.queue, self.ds, placehdrs_nobatch)
def setup_training(self, trainer):
self.setup(trainer.model)
super(BatchQueueInput, self).setup_training(trainer)
trainer.register_callback(StartProcOrThread(self.thread))
def get_input_tensors(self):
......@@ -282,3 +283,26 @@ class TensorInput(FeedfreeInput):
def get_input_tensors(self):
return self.get_tensor_fn()
class ZMQInput(FeedfreeInput):
def __init__(self, endpoint):
self._endpoint = endpoint
def size(self):
raise NotImplementedError()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \
"ZMQInput has to be used with input placeholders!"
def get_input_tensors(self):
from tensorpack.user_ops import zmq_recv
ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
if isinstance(self._recv, tf.Tensor):
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
qv.set_shape(v.get_shape())
return ret
......@@ -55,5 +55,5 @@ $(OBJ_DIR)/%.d: %.cc Makefile
@$(CXX) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cc=.o) $(OBJ_DIR)/$(<:.cc=.d)" "$<" > "$@" || rm "$@"
clean:
@rm -rvf $(OBJ_DIR)
@rm -rvf $(OBJ_DIR) $(SO)
......@@ -16,7 +16,6 @@ print("Compiling user ops ...")
ret = os.system(compile_cmd)
if ret != 0:
print("Failed to compile user ops!")
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
zmq_recv = recv_mod.zmq_recv
else:
recv_mod = tf.load_op_library(os.path.join(file_dir, 'zmq_recv_op.so'))
zmq_recv = recv_mod.zmq_recv
//File: recv_op.cc
//File: zmq_recv_op.cc
//Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#include <string>
......
......@@ -15,7 +15,8 @@ from tensorflow.core.framework import types_pb2 as DataType
msgpack_numpy.patch()
__all__ = ['loads', 'dumps']
__all__ = ['loads', 'dumps', 'dumps_for_tfop', 'dump_tensor_protos',
'to_tensor_proto']
def dumps(obj):
......@@ -46,7 +47,7 @@ _DTYPE_DICT = {
_DTYPE_DICT = {np.dtype(k): v for k, v in _DTYPE_DICT.items()}
# TODO support string tensor
# TODO support string tensor and scalar
def to_tensor_proto(arr):
"""
Convert a numpy array to TensorProto
......@@ -86,3 +87,8 @@ def dump_tensor_protos(protos):
s += struct.pack('=i', len(buf)) # won't send stuff over 2G
s += buf
return s
def dumps_for_tfop(dp):
protos = [to_tensor_proto(arr) for arr in dp]
return dump_tensor_protos(protos)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment