Commit 28f36c44 authored by Yuxin Wu's avatar Yuxin Wu

fix bug on augmentors; bind=False in send_dataflow_zmq

parent c7e8ba78
...@@ -199,8 +199,8 @@ class AugmentImageComponents(MapData): ...@@ -199,8 +199,8 @@ class AugmentImageComponents(MapData):
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch(): with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design? major_image = index[0] # image to be used to get params. TODO better design?
check_dtype(major_image)
im = copy_func(dp[major_image]) im = copy_func(dp[major_image])
check_dtype(im)
im, prms = self.augs._augment_return_params(im) im, prms = self.augs._augment_return_params(im)
dp[major_image] = im dp[major_image] = im
for idx in index[1:]: for idx in index[1:]:
......
...@@ -20,7 +20,7 @@ else: ...@@ -20,7 +20,7 @@ else:
__all__ = ['send_dataflow_zmq', 'RemoteDataZMQ'] __all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
def send_dataflow_zmq(df, addr, hwm=50, format=None): def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
""" """
Run DataFlow and send data to a ZMQ socket addr. Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr, It will __connect__ to this addr,
...@@ -34,6 +34,7 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None): ...@@ -34,6 +34,7 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
format (str): The serialization format. format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize`. Default format would use :mod:`tensorpack.utils.serialize`.
An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops. An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops.
bind (bool): whether to bind or connect to the endpoint.
""" """
assert format in [None, 'zmq_op'] assert format in [None, 'zmq_op']
if format is None: if format is None:
...@@ -45,7 +46,10 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None): ...@@ -45,7 +46,10 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
ctx = zmq.Context() ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH) socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm) socket.set_hwm(hwm)
socket.connect(addr) if bind:
socket.bind(addr)
else:
socket.connect(addr)
try: try:
df.reset_state() df.reset_state()
logger.info("Serving data to {} ...".format(addr)) logger.info("Serving data to {} ...".format(addr))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment