Commit 28f36c44 authored by Yuxin Wu's avatar Yuxin Wu

fix bug on augmentors; bind=False in send_dataflow_zmq

parent c7e8ba78
......@@ -199,8 +199,8 @@ class AugmentImageComponents(MapData):
copy_func = copy_mod.deepcopy if copy else lambda x: x # noqa
with exception_handler.catch():
major_image = index[0] # image to be used to get params. TODO better design?
check_dtype(major_image)
im = copy_func(dp[major_image])
check_dtype(im)
im, prms = self.augs._augment_return_params(im)
dp[major_image] = im
for idx in index[1:]:
......
......@@ -20,7 +20,7 @@ else:
__all__ = ['send_dataflow_zmq', 'RemoteDataZMQ']
def send_dataflow_zmq(df, addr, hwm=50, format=None):
def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr,
......@@ -34,6 +34,7 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize`.
An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops.
bind (bool): whether to bind or connect to the endpoint.
"""
assert format in [None, 'zmq_op']
if format is None:
......@@ -45,6 +46,9 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None):
ctx = zmq.Context()
socket = ctx.socket(zmq.PUSH)
socket.set_hwm(hwm)
if bind:
socket.bind(addr)
else:
socket.connect(addr)
try:
df.reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment