Commit 99b9c39c authored by Yuxin Wu's avatar Yuxin Wu

Add hwm option for RemoteData (#432)

parent 1f5c764d
...@@ -153,7 +153,7 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -153,7 +153,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
nr_proc (int): number of processes to use. nr_proc (int): number of processes to use.
hwm (int): the zmq "high-water mark" for both sender and receiver. hwm (int): the zmq "high-water mark" (queue size) for both sender and receiver.
""" """
assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! PrefetchData might work sometimes." assert os.name != 'nt', "PrefetchDataZMQ doesn't support windows! PrefetchData might work sometimes."
super(PrefetchDataZMQ, self).__init__(ds) super(PrefetchDataZMQ, self).__init__(ds)
......
...@@ -27,7 +27,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None): ...@@ -27,7 +27,7 @@ def send_dataflow_zmq(df, addr, hwm=50, print_interval=100, format=None):
Args: Args:
df (DataFlow): Will infinitely loop over the DataFlow. df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr. addr: a ZMQ socket addr.
hwm (int): high water mark hwm (int): ZMQ high-water mark (buffer size)
""" """
# format (str): The serialization format. ZMQ Op is still not publicly usable now # format (str): The serialization format. ZMQ Op is still not publicly usable now
# Default format would use :mod:`tensorpack.utils.serialize`. # Default format would use :mod:`tensorpack.utils.serialize`.
...@@ -65,16 +65,18 @@ class RemoteDataZMQ(DataFlow): ...@@ -65,16 +65,18 @@ class RemoteDataZMQ(DataFlow):
Attributes: Attributes:
cnt1, cnt2 (int): number of data points received from addr1 and addr2 cnt1, cnt2 (int): number of data points received from addr1 and addr2
""" """
def __init__(self, addr1, addr2=None): def __init__(self, addr1, addr2=None, hwm=50):
""" """
Args: Args:
addr1,addr2 (str): addr of the socket to connect to. addr1,addr2 (str): addr of the socket to connect to.
Use both if you need two protocols (e.g. both IPC and TCP). Use both if you need two protocols (e.g. both IPC and TCP).
I don't think you'll ever need 3. I don't think you'll ever need 3.
hwm (int): ZMQ high-water mark (buffer size)
""" """
assert addr1 assert addr1
self._addr1 = addr1 self._addr1 = addr1
self._addr2 = addr2 self._addr2 = addr2
self._hwm = int(hwm)
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
def reset_state(self): def reset_state(self):
...@@ -87,7 +89,7 @@ class RemoteDataZMQ(DataFlow): ...@@ -87,7 +89,7 @@ class RemoteDataZMQ(DataFlow):
ctx = zmq.Context() ctx = zmq.Context()
if self._addr2 is None: if self._addr2 is None:
socket = ctx.socket(zmq.PULL) socket = ctx.socket(zmq.PULL)
socket.set_hwm(50) socket.set_hwm(self._hwm)
socket.bind(self._addr1) socket.bind(self._addr1)
while True: while True:
...@@ -96,11 +98,11 @@ class RemoteDataZMQ(DataFlow): ...@@ -96,11 +98,11 @@ class RemoteDataZMQ(DataFlow):
self.cnt1 += 1 self.cnt1 += 1
else: else:
socket1 = ctx.socket(zmq.PULL) socket1 = ctx.socket(zmq.PULL)
socket1.set_hwm(50) socket1.set_hwm(self._hwm)
socket1.bind(self._addr1) socket1.bind(self._addr1)
socket2 = ctx.socket(zmq.PULL) socket2 = ctx.socket(zmq.PULL)
socket2.set_hwm(50) socket2.set_hwm(self._hwm)
socket2.bind(self._addr2) socket2.bind(self._addr2)
poller = zmq.Poller() poller = zmq.Poller()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment