Commit 8d3c709c authored by Yuxin Wu

py3 compat & simulator speedup

parent 4fc21080
...@@ -65,18 +65,18 @@ class AtariPlayer(RLEnvironment): ...@@ -65,18 +65,18 @@ class AtariPlayer(RLEnvironment):
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 10000)) self.ale.setInt(b"random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False) self.ale.setBool(b"showinfo", False)
self.ale.setInt("frame_skip", 1) self.ale.setInt(b"frame_skip", 1)
self.ale.setBool('color_averaging', False) self.ale.setBool(b'color_averaging', False)
# manual.pdf suggests otherwise. # manual.pdf suggests otherwise.
self.ale.setFloat('repeat_action_probability', 0.0) self.ale.setFloat(b'repeat_action_probability', 0.0)
# viz setup # viz setup
if isinstance(viz, six.string_types): if isinstance(viz, six.string_types):
assert os.path.isdir(viz), viz assert os.path.isdir(viz), viz
self.ale.setString('record_screen_dir', viz) self.ale.setString(b'record_screen_dir', viz)
viz = 0 viz = 0
if isinstance(viz, int): if isinstance(viz, int):
viz = float(viz) viz = float(viz)
...@@ -86,7 +86,7 @@ class AtariPlayer(RLEnvironment): ...@@ -86,7 +86,7 @@ class AtariPlayer(RLEnvironment):
cv2.startWindowThread() cv2.startWindowThread()
cv2.namedWindow(self.windowname) cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file) self.ale.loadROM(rom_file.encode('utf-8'))
self.width, self.height = self.ale.getScreenDims() self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet() self.actions = self.ale.getMinimalActionSet()
...@@ -184,7 +184,7 @@ if __name__ == '__main__': ...@@ -184,7 +184,7 @@ if __name__ == '__main__':
cnt += 1 cnt += 1
if cnt == 5000: if cnt == 5000:
break break
print time.time() - start print(time.time() - start)
if len(sys.argv) == 3 and sys.argv[2] == 'benchmark': if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
import threading, multiprocessing import threading, multiprocessing
......
...@@ -39,33 +39,30 @@ class SimulatorProcess(multiprocessing.Process): ...@@ -39,33 +39,30 @@ class SimulatorProcess(multiprocessing.Process):
self.c2s = pipe_c2s self.c2s = pipe_c2s
self.s2c = pipe_s2c self.s2c = pipe_s2c
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
def run(self): def run(self):
player = self._build_player() player = self._build_player()
context = zmq.Context() context = zmq.Context()
c2s_socket = context.socket(zmq.DEALER) c2s_socket = context.socket(zmq.PUSH)
c2s_socket.identity = 'simulator-{}'.format(self.idx) c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
c2s_socket.set_hwm(2) c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s) c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER) s2c_socket = context.socket(zmq.DEALER)
s2c_socket.identity = 'simulator-{}'.format(self.idx) s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
#s2c_socket.set_hwm(5) #s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c) s2c_socket.connect(self.s2c)
#cnt = 0 state = player.current_state()
reward, isOver = 0, False
while True: while True:
state = player.current_state() c2s_socket.send(dumps(
c2s_socket.send(dumps(state), copy=False) (self.identity, state, reward, isOver)),
#with total_timer('client recv_action'): copy=False)
data = s2c_socket.recv(copy=False) action = loads(s2c_socket.recv(copy=False))
action = loads(data)
reward, isOver = player.action(action) reward, isOver = player.action(action)
c2s_socket.send(dumps((reward, isOver)), copy=False) state = player.current_state()
#with total_timer('client recv_ack'):
ACK = s2c_socket.recv(copy=False)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
@abstractmethod @abstractmethod
def _build_player(self): def _build_player(self):
...@@ -80,7 +77,6 @@ class SimulatorMaster(threading.Thread): ...@@ -80,7 +77,6 @@ class SimulatorMaster(threading.Thread):
class ClientState(object): class ClientState(object):
def __init__(self): def __init__(self):
self.protocol_state = 0 # state in communication
self.memory = [] # list of Experience self.memory = [] # list of Experience
class Experience(object): class Experience(object):
...@@ -95,21 +91,25 @@ class SimulatorMaster(threading.Thread): ...@@ -95,21 +91,25 @@ class SimulatorMaster(threading.Thread):
def __init__(self, pipe_c2s, pipe_s2c): def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__() super(SimulatorMaster, self).__init__()
self.daemon = True
self.context = zmq.Context() self.context = zmq.Context()
self.c2s_socket = self.context.socket(zmq.ROUTER) self.c2s_socket = self.context.socket(zmq.PULL)
self.c2s_socket.bind(pipe_c2s) self.c2s_socket.bind(pipe_c2s)
self.c2s_socket.set_hwm(10)
self.s2c_socket = self.context.socket(zmq.ROUTER) self.s2c_socket = self.context.socket(zmq.ROUTER)
self.s2c_socket.bind(pipe_s2c) self.s2c_socket.bind(pipe_s2c)
self.s2c_socket.set_hwm(10)
self.socket_lock = threading.Lock()
self.daemon = True
# queueing messages to client # queueing messages to client
self.send_queue = queue.Queue(maxsize=100) self.send_queue = queue.Queue(maxsize=100)
self.send_thread = LoopThread(lambda:
self.s2c_socket.send_multipart(self.send_queue.get())) def f():
msg = self.send_queue.get()
# slow
self.s2c_socket.send_multipart(msg, copy=False)
self.send_thread = LoopThread(f)
self.send_thread.daemon = True self.send_thread.daemon = True
self.send_thread.start() self.send_thread.start()
...@@ -123,21 +123,25 @@ class SimulatorMaster(threading.Thread): ...@@ -123,21 +123,25 @@ class SimulatorMaster(threading.Thread):
def run(self): def run(self):
self.clients = defaultdict(self.ClientState) self.clients = defaultdict(self.ClientState)
#cnt = 0
while True: while True:
ident, msg = self.c2s_socket.recv_multipart() #cnt += 1
#if cnt % 3000 == 0:
#print_total_timer()
msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg
client = self.clients[ident] client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action # check if reward&isOver is valid
state = loads(msg) # in the first message, only state is valid
self._on_state(state, ident) if len(client.memory) > 0:
else: # reward-response
reward, isOver = loads(msg)
client.memory[-1].reward = reward client.memory[-1].reward = reward
if isOver: if isOver:
self._on_episode_over(ident) self._on_episode_over(ident)
else: else:
self._on_datapoint(ident) self._on_datapoint(ident)
self.send_queue.put([ident, 'Thanks']) # just an ACK # feed state and return action
self._on_state(state, ident)
@abstractmethod @abstractmethod
def _on_state(self, state, ident): def _on_state(self, state, ident):
......
...@@ -39,9 +39,9 @@ class AsyncPredictorBase(PredictorBase): ...@@ -39,9 +39,9 @@ class AsyncPredictorBase(PredictorBase):
""" """
:param dp: A data point (list of component) as inputs. :param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation) (It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with the list of :param callback: a thread-safe callback to get called with
outputs of (inputs, outputs) pair either outputs or (inputs, outputs)
:return: a Future of outputs :return: a Future of results
""" """
@abstractmethod @abstractmethod
......
...@@ -82,16 +82,16 @@ class PredictorWorkerThread(threading.Thread): ...@@ -82,16 +82,16 @@ class PredictorWorkerThread(threading.Thread):
self.id = id self.id = id
def run(self): def run(self):
#self.xxx = None
while True: while True:
batched, futures = self.fetch_batch() batched, futures = self.fetch_batch()
outputs = self.func(batched) outputs = self.func(batched)
#print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize() #print "Worker {} batched {} Queue {}".format(
#self.id, len(futures), self.queue.qsize())
# debug, for speed testing # debug, for speed testing
#if self.xxx is None: #if not hasattr(self, 'xxx'):
#self.xxx = outputs = self.func([batched]) #self.xxx = outputs = self.func(batched)
#else: #else:
#outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)] #outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for idx, f in enumerate(futures): for idx, f in enumerate(futures):
f.set_result([k[idx] for k in outputs]) f.set_result([k[idx] for k in outputs])
...@@ -125,7 +125,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase): ...@@ -125,7 +125,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" :param predictors: a list of OnlinePredictor""" """ :param predictors: a list of OnlinePredictor"""
for k in predictors: for k in predictors:
assert isinstance(k, OnlinePredictor), type(k) assert isinstance(k, OnlinePredictor), type(k)
self.input_queue = queue.Queue(maxsize=len(predictors)*10) # TODO use predictors.return_input here
assert k.return_input == False
self.input_queue = queue.Queue(maxsize=len(predictors)*100)
self.threads = [ self.threads = [
PredictorWorkerThread( PredictorWorkerThread(
self.input_queue, f, id, batch_size=batch_size) self.input_queue, f, id, batch_size=batch_size)
......
...@@ -115,7 +115,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer): ...@@ -115,7 +115,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
train_op = self.config.optimizer.apply_gradients(grad_list[k]) train_op = self.config.optimizer.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding def f(op=train_op): # avoid late-binding
self.sess.run([op]) self.sess.run([op])
self.async_step_counter.next() next(self.async_step_counter)
th = LoopThread(f) th = LoopThread(f)
th.pause() th.pause()
th.start() th.start()
...@@ -127,7 +127,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer): ...@@ -127,7 +127,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
self.async_running = True self.async_running = True
for th in self.training_threads: # resume all threads for th in self.training_threads: # resume all threads
th.resume() th.resume()
self.async_step_counter.next() next(self.async_step_counter)
super(AsyncMultiGPUTrainer, self).run_step() super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self): def _trigger_epoch(self):
......
...@@ -202,6 +202,9 @@ class QueueInputTrainer(Trainer): ...@@ -202,6 +202,9 @@ class QueueInputTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()), self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op') summary_moving_average(), name='train_op')
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
self.main_loop() self.main_loop()
def run_step(self): def run_step(self):
...@@ -218,8 +221,6 @@ class QueueInputTrainer(Trainer): ...@@ -218,8 +221,6 @@ class QueueInputTrainer(Trainer):
#trace_file.write(trace.generate_chrome_trace_format()) #trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit() #import sys; sys.exit()
#self.sess.run([self.dequed_inputs[1]])
def _trigger_epoch(self): def _trigger_epoch(self):
# need to run summary_op every epoch # need to run summary_op every epoch
# note that summary_op will take a data from the queue # note that summary_op will take a data from the queue
...@@ -234,3 +235,5 @@ class QueueInputTrainer(Trainer): ...@@ -234,3 +235,5 @@ class QueueInputTrainer(Trainer):
""" """
return self.predictor_factory.get_predictor(input_names, output_names, tower) return self.predictor_factory.get_predictor(input_names, output_names, tower)
def get_predict_funcs(self, input_names, output_names, n):
    """ Build `n` predictors by calling `get_predict_func` once per index 0..n-1.

    :param input_names: forwarded unchanged to `get_predict_func`.
    :param output_names: forwarded unchanged to `get_predict_func`.
    :param n: number of predictors to build.
    :return: a list with the `n` results of `get_predict_func`, in index order.
    """
    funcs = []
    for idx in range(n):
        # the index is passed as the third positional argument
        funcs.append(self.get_predict_func(input_names, output_names, idx))
    return funcs
...@@ -13,7 +13,27 @@ import atexit ...@@ -13,7 +13,27 @@ import atexit
from .stat import StatCounter from .stat import StatCounter
from . import logger from . import logger
__all__ = ['total_timer', 'timed_operation', 'print_total_timer'] __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object):
    """ Count calls and periodically log the elapsed time and the time per call.

    :param print_every: log once every this many calls (converted with ``int``).
    :param name: prefix of the log line; 'IterSpeed' is used when it is
        empty or None.
    """

    def __init__(self, print_every, name=None):
        self.name = name or 'IterSpeed'
        self.print_every = int(print_every)
        self.cnt = 0

    def reset(self):
        """ Restart the clock. The call count is left untouched. """
        self.start = time.time()

    def __call__(self):
        """ Record one call; log statistics when the count is a multiple of `print_every`. """
        # the clock starts at the first call, not at construction
        if not self.cnt:
            self.reset()
        self.cnt += 1
        if self.cnt % self.print_every == 0:
            elapsed = time.time() - self.start
            per_call = elapsed / self.cnt
            logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
                self.name, elapsed, self.cnt, per_call))
@contextmanager @contextmanager
def timed_operation(msg, log_start=False): def timed_operation(msg, log_start=False):
...@@ -37,7 +57,7 @@ def print_total_timer(): ...@@ -37,7 +57,7 @@ def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0: if len(_TOTAL_TIMER_DATA) == 0:
return return
for k, v in six.iteritems(_TOTAL_TIMER_DATA): for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format( logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
k, v.sum, v.count, v.average)) k, v.sum, v.count, v.average))
atexit.register(print_total_timer) atexit.register(print_total_timer)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment