Commit 335d6c28 authored by Yuxin Wu's avatar Yuxin Wu

faster serialize

parent 126fd26a
...@@ -43,7 +43,7 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -43,7 +43,7 @@ class SimulatorProcess(multiprocessing.Process):
context = zmq.Context() context = zmq.Context()
c2s_socket = context.socket(zmq.DEALER) c2s_socket = context.socket(zmq.DEALER)
c2s_socket.identity = 'simulator-{}'.format(self.idx) c2s_socket.identity = 'simulator-{}'.format(self.idx)
#c2s_socket.set_hwm(2) c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s) c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER) s2c_socket = context.socket(zmq.DEALER)
...@@ -54,8 +54,7 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -54,8 +54,7 @@ class SimulatorProcess(multiprocessing.Process):
#cnt = 0 #cnt = 0
while True: while True:
state = player.current_state() state = player.current_state()
#c2s_socket.send(dumps(state), copy=False) c2s_socket.send(dumps(state), copy=False)
c2s_socket.send('h')
#with total_timer('client recv_action'): #with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False) data = s2c_socket.recv(copy=False)
action = loads(data) action = loads(data)
...@@ -126,8 +125,7 @@ class SimulatorMaster(threading.Thread): ...@@ -126,8 +125,7 @@ class SimulatorMaster(threading.Thread):
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action if not client.protocol_state == 0: # state-action
#state = loads(msg) state = loads(msg)
state = np.zeros((84, 84, 4), dtype='float32')
self._on_state(state, ident) self._on_state(state, ident)
else: # reward-response else: # reward-response
reward, isOver = loads(msg) reward, isOver = loads(msg)
......
...@@ -112,6 +112,7 @@ class PredictorWorkerThread(threading.Thread): ...@@ -112,6 +112,7 @@ class PredictorWorkerThread(threading.Thread):
batched, futures = fetch() batched, futures = fetch()
#print "batched size: ", len(batched), "queuesize: ", self.queue.qsize() #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
outputs = self.func([batched]) outputs = self.func([batched])
# debug, for speed testing
#if self.xxx is None: #if self.xxx is None:
#outputs = self.func([batched]) #outputs = self.func([batched])
#self.xxx = outputs #self.xxx = outputs
......
...@@ -3,17 +3,17 @@ ...@@ -3,17 +3,17 @@
# File: serialize.py # File: serialize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
#import msgpack import msgpack
#import msgpack_numpy import msgpack_numpy
#msgpack_numpy.patch() msgpack_numpy.patch()
import dill #import dill
__all__ = ['loads', 'dumps'] __all__ = ['loads', 'dumps']
def dumps(obj): def dumps(obj):
return dill.dumps(obj) #return dill.dumps(obj)
#return msgpack.dumps(obj, use_bin_type=True) return msgpack.dumps(obj, use_bin_type=True)
def loads(buf): def loads(buf):
return dill.loads(buf) #return dill.loads(buf)
#return msgpack.loads(buf) return msgpack.loads(buf)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment