Commit 8d3c709c authored by Yuxin Wu

py3 compat & simulator speedup

parent 4fc21080
......@@ -65,18 +65,18 @@ class AtariPlayer(RLEnvironment):
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
self.ale.setInt(b"random_seed", self.rng.randint(0, 10000))
self.ale.setBool(b"showinfo", False)
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
self.ale.setInt(b"frame_skip", 1)
self.ale.setBool(b'color_averaging', False)
# manual.pdf suggests otherwise.
self.ale.setFloat('repeat_action_probability', 0.0)
self.ale.setFloat(b'repeat_action_probability', 0.0)
# viz setup
if isinstance(viz, six.string_types):
assert os.path.isdir(viz), viz
self.ale.setString('record_screen_dir', viz)
self.ale.setString(b'record_screen_dir', viz)
viz = 0
if isinstance(viz, int):
viz = float(viz)
......@@ -86,7 +86,7 @@ class AtariPlayer(RLEnvironment):
cv2.startWindowThread()
cv2.namedWindow(self.windowname)
self.ale.loadROM(rom_file)
self.ale.loadROM(rom_file.encode('utf-8'))
self.width, self.height = self.ale.getScreenDims()
self.actions = self.ale.getMinimalActionSet()
......@@ -184,7 +184,7 @@ if __name__ == '__main__':
cnt += 1
if cnt == 5000:
break
print time.time() - start
print(time.time() - start)
if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
import threading, multiprocessing
......
......@@ -39,33 +39,30 @@ class SimulatorProcess(multiprocessing.Process):
self.c2s = pipe_c2s
self.s2c = pipe_s2c
self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')
def run(self):
player = self._build_player()
context = zmq.Context()
c2s_socket = context.socket(zmq.DEALER)
c2s_socket.identity = 'simulator-{}'.format(self.idx)
c2s_socket = context.socket(zmq.PUSH)
c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
c2s_socket.set_hwm(2)
c2s_socket.connect(self.c2s)
s2c_socket = context.socket(zmq.DEALER)
s2c_socket.identity = 'simulator-{}'.format(self.idx)
s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
#s2c_socket.set_hwm(5)
s2c_socket.connect(self.s2c)
#cnt = 0
while True:
state = player.current_state()
c2s_socket.send(dumps(state), copy=False)
#with total_timer('client recv_action'):
data = s2c_socket.recv(copy=False)
action = loads(data)
reward, isOver = 0, False
while True:
c2s_socket.send(dumps(
(self.identity, state, reward, isOver)),
copy=False)
action = loads(s2c_socket.recv(copy=False))
reward, isOver = player.action(action)
c2s_socket.send(dumps((reward, isOver)), copy=False)
#with total_timer('client recv_ack'):
ACK = s2c_socket.recv(copy=False)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
state = player.current_state()
@abstractmethod
def _build_player(self):
......@@ -80,7 +77,6 @@ class SimulatorMaster(threading.Thread):
class ClientState(object):
def __init__(self):
self.protocol_state = 0 # state in communication
self.memory = [] # list of Experience
class Experience(object):
......@@ -95,21 +91,25 @@ class SimulatorMaster(threading.Thread):
def __init__(self, pipe_c2s, pipe_s2c):
super(SimulatorMaster, self).__init__()
self.daemon = True
self.context = zmq.Context()
self.c2s_socket = self.context.socket(zmq.ROUTER)
self.c2s_socket = self.context.socket(zmq.PULL)
self.c2s_socket.bind(pipe_c2s)
self.c2s_socket.set_hwm(10)
self.s2c_socket = self.context.socket(zmq.ROUTER)
self.s2c_socket.bind(pipe_s2c)
self.socket_lock = threading.Lock()
self.daemon = True
self.s2c_socket.set_hwm(10)
# queueing messages to client
self.send_queue = queue.Queue(maxsize=100)
self.send_thread = LoopThread(lambda:
self.s2c_socket.send_multipart(self.send_queue.get()))
def f():
msg = self.send_queue.get()
# slow
self.s2c_socket.send_multipart(msg, copy=False)
self.send_thread = LoopThread(f)
self.send_thread.daemon = True
self.send_thread.start()
......@@ -123,21 +123,25 @@ class SimulatorMaster(threading.Thread):
def run(self):
self.clients = defaultdict(self.ClientState)
#cnt = 0
while True:
ident, msg = self.c2s_socket.recv_multipart()
#cnt += 1
#if cnt % 3000 == 0:
#print_total_timer()
msg = loads(self.c2s_socket.recv(copy=False).bytes)
ident, state, reward, isOver = msg
client = self.clients[ident]
client.protocol_state = 1 - client.protocol_state # first flip the state
if not client.protocol_state == 0: # state-action
state = loads(msg)
self._on_state(state, ident)
else: # reward-response
reward, isOver = loads(msg)
# check if reward&isOver is valid
# in the first message, only state is valid
if len(client.memory) > 0:
client.memory[-1].reward = reward
if isOver:
self._on_episode_over(ident)
else:
self._on_datapoint(ident)
self.send_queue.put([ident, 'Thanks']) # just an ACK
# feed state and return action
self._on_state(state, ident)
@abstractmethod
def _on_state(self, state, ident):
......
......@@ -39,9 +39,9 @@ class AsyncPredictorBase(PredictorBase):
"""
:param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with the list of
outputs of (inputs, outputs) pair
:return: a Future of outputs
:param callback: a thread-safe callback to get called with
either outputs or (inputs, outputs)
:return: a Future of results
"""
@abstractmethod
......
......@@ -82,16 +82,16 @@ class PredictorWorkerThread(threading.Thread):
self.id = id
def run(self):
#self.xxx = None
while True:
batched, futures = self.fetch_batch()
outputs = self.func(batched)
#print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize()
#print "Worker {} batched {} Queue {}".format(
#self.id, len(futures), self.queue.qsize())
# debug, for speed testing
#if self.xxx is None:
#self.xxx = outputs = self.func([batched])
#if not hasattr(self, 'xxx'):
#self.xxx = outputs = self.func(batched)
#else:
#outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
#outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for idx, f in enumerate(futures):
f.set_result([k[idx] for k in outputs])
......@@ -125,7 +125,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" :param predictors: a list of OnlinePredictor"""
for k in predictors:
assert isinstance(k, OnlinePredictor), type(k)
self.input_queue = queue.Queue(maxsize=len(predictors)*10)
# TODO use predictors.return_input here
assert k.return_input == False
self.input_queue = queue.Queue(maxsize=len(predictors)*100)
self.threads = [
PredictorWorkerThread(
self.input_queue, f, id, batch_size=batch_size)
......
......@@ -115,7 +115,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
train_op = self.config.optimizer.apply_gradients(grad_list[k])
def f(op=train_op): # avoid late-binding
self.sess.run([op])
self.async_step_counter.next()
next(self.async_step_counter)
th = LoopThread(f)
th.pause()
th.start()
......@@ -127,7 +127,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
self.async_running = True
for th in self.training_threads: # resume all threads
th.resume()
self.async_step_counter.next()
next(self.async_step_counter)
super(AsyncMultiGPUTrainer, self).run_step()
def _trigger_epoch(self):
......
......@@ -202,6 +202,9 @@ class QueueInputTrainer(Trainer):
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
summary_moving_average(), name='train_op')
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
self.main_loop()
def run_step(self):
......@@ -218,8 +221,6 @@ class QueueInputTrainer(Trainer):
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
#self.sess.run([self.dequed_inputs[1]])
def _trigger_epoch(self):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
......@@ -234,3 +235,5 @@ class QueueInputTrainer(Trainer):
"""
return self.predictor_factory.get_predictor(input_names, output_names, tower)
def get_predict_funcs(self, input_names, output_names, n):
    """Build a list of ``n`` predictors via :meth:`get_predict_func`.

    :param input_names: forwarded unchanged to ``get_predict_func``.
    :param output_names: forwarded unchanged to ``get_predict_func``.
    :param n: number of predictors to create; indices ``0 .. n-1`` are
        passed as the third argument (presumably the tower index -- confirm
        against ``get_predict_func``).
    :return: a list of length ``n``, in index order.
    """
    predict_funcs = []
    for index in range(n):
        predict_funcs.append(
            self.get_predict_func(input_names, output_names, index))
    return predict_funcs
......@@ -13,7 +13,27 @@ import atexit
from .stat import StatCounter
from . import logger
__all__ = ['total_timer', 'timed_operation', 'print_total_timer']
__all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object):
    """Count how often a code path is reached and periodically log its speed.

    Call the instance once per iteration. Every ``print_every`` calls it logs
    the elapsed wall-clock time since the last :meth:`reset`, the total number
    of calls, and their ratio.

    :param print_every: log once every this many calls; converted with
        ``int()`` and must not be zero after conversion.
    :param name: label used as prefix of the log line; defaults to
        ``'IterSpeed'`` when empty or None.
    :raises ValueError: if ``int(print_every)`` is 0.
    """

    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
        if self.print_every == 0:
            # Fail early: a zero period would otherwise raise
            # ZeroDivisionError from the modulo on every __call__.
            raise ValueError(
                "print_every must be non-zero, got {!r}".format(print_every))
        self.name = name if name else 'IterSpeed'

    def reset(self):
        """Restart the wall-clock timer. The call counter is left untouched."""
        self.start = time.time()

    def __call__(self):
        """Record one iteration; log statistics every ``print_every`` calls."""
        if self.cnt == 0:
            # The timer starts at the first recorded call, not at construction.
            self.reset()
        self.cnt += 1
        if self.cnt % self.print_every != 0:
            return
        t = time.time() - self.start
        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
            self.name, t, self.cnt, t / self.cnt))
@contextmanager
def timed_operation(msg, log_start=False):
......@@ -37,7 +57,7 @@ def print_total_timer():
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
k, v.sum, v.count, v.average))
atexit.register(print_total_timer)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment