Commit 3dd7171d authored by Tal Sh, committed by Yuxin Wu

Removed excess copying when fetching data from ZMQ (#746)

* Removed excess copying when fetching data from ZMQ

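Accessing `.bytes` on the message returned by `recv(copy=False)` explicitly triggers a copy in order to create a Python `bytes` object. This isn't necessary for pyarrow deserialization, so the received message is now passed to `loads` directly.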

* Remove all the `.bytes`!

* Fixed a crash when deleting a `MultiThreadPrefetchData` instance on which `reset_state()` was never called (`__del__` now stops and joins only the threads that are alive)
parent a156f2bf
...@@ -295,7 +295,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -295,7 +295,7 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
self._size = -1 self._size = -1
def _recv(self): def _recv(self):
return loads(self.socket.recv(copy=False).bytes) return loads(self.socket.recv(copy=False))
def size(self): def size(self):
return self.ds.size() return self.ds.size()
...@@ -399,8 +399,9 @@ class MultiThreadPrefetchData(DataFlow): ...@@ -399,8 +399,9 @@ class MultiThreadPrefetchData(DataFlow):
def __del__(self): def __del__(self):
for p in self.threads: for p in self.threads:
p.stop() if p.is_alive():
p.join() p.stop()
p.join()
class PlasmaPutData(ProxyDataFlow): class PlasmaPutData(ProxyDataFlow):
......
...@@ -233,7 +233,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -233,7 +233,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
socket.connect(self.pipename) socket.connect(self.pipename)
while True: while True:
dp = loads(socket.recv(copy=False).bytes) dp = loads(socket.recv(copy=False))
dp = self.map_func(dp) dp = self.map_func(dp)
socket.send(dumps(dp), copy=False) socket.send(dumps(dp), copy=False)
...@@ -283,7 +283,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow): ...@@ -283,7 +283,7 @@ class MultiProcessMapDataZMQ(_ParallelMapData, _MultiProcessZMQDataFlow):
def _recv(self): def _recv(self):
msg = self.socket.recv_multipart(copy=False) msg = self.socket.recv_multipart(copy=False)
dp = loads(msg[1].bytes) dp = loads(msg[1])
return dp return dp
def get_data(self): def get_data(self):
......
...@@ -124,7 +124,7 @@ class RemoteDataZMQ(DataFlow): ...@@ -124,7 +124,7 @@ class RemoteDataZMQ(DataFlow):
self.bind_or_connect(socket, self._addr1) self.bind_or_connect(socket, self._addr1)
while True: while True:
dp = loads(socket.recv(copy=False).bytes) dp = loads(socket.recv(copy=False))
yield dp yield dp
self.cnt1 += 1 self.cnt1 += 1
else: else:
...@@ -143,7 +143,7 @@ class RemoteDataZMQ(DataFlow): ...@@ -143,7 +143,7 @@ class RemoteDataZMQ(DataFlow):
while True: while True:
evts = poller.poll() evts = poller.poll()
for sock, evt in evts: for sock, evt in evts:
dp = loads(sock.recv(copy=False).bytes) dp = loads(sock.recv(copy=False))
yield dp yield dp
if sock == socket1: if sock == socket1:
self.cnt1 += 1 self.cnt1 += 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment