Commit 40cee0cb authored by Yuxin Wu's avatar Yuxin Wu

simplify code

parent 6e7eef04
...@@ -63,6 +63,10 @@ def get_player(viz=False, train=False): ...@@ -63,6 +63,10 @@ def get_player(viz=False, train=False):
live_lost_as_eoe=train) live_lost_as_eoe=train)
global NUM_ACTIONS global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions() NUM_ACTIONS = pl.get_num_actions()
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
return pl return pl
class Model(ModelDesc): class Model(ModelDesc):
...@@ -162,7 +166,7 @@ def play_one_episode(player, func, verbose=False): ...@@ -162,7 +166,7 @@ def play_one_episode(player, func, verbose=False):
return sc return sc
def play_model(model_path): def play_model(model_path):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.013), FRAME_HISTORY), 30, 1) player = get_player(0.013)
cfg = PredictConfig( cfg = PredictConfig(
model=Model(), model=Model(),
input_data_mapping=[0], input_data_mapping=[0],
...@@ -180,15 +184,10 @@ def eval_with_funcs(predict_funcs): ...@@ -180,15 +184,10 @@ def eval_with_funcs(predict_funcs):
self.func = func self.func = func
self.q = queue self.q = queue
def run(self): def run(self):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1) player = get_player()
while not self.stopped(): while not self.stopped():
score = play_one_episode(player, self.func) score = play_one_episode(player, self.func)
while not self.stopped(): self.queue_put_stoppable(self.q, score)
try:
self.q.put(score, timeout=5)
break
except queue.Queue.Full:
pass
q = queue.Queue() q = queue.Queue()
threads = [Worker(f, q) for f in predict_funcs] threads = [Worker(f, q) for f in predict_funcs]
...@@ -201,9 +200,9 @@ def eval_with_funcs(predict_funcs): ...@@ -201,9 +200,9 @@ def eval_with_funcs(predict_funcs):
for _ in tqdm(range(EVAL_EPISODE)): for _ in tqdm(range(EVAL_EPISODE)):
r = q.get() r = q.get()
stat.feed(r) stat.feed(r)
finally:
for k in threads: k.stop() for k in threads: k.stop()
for k in threads: k.join() for k in threads: k.join()
finally:
return (stat.average, stat.max) return (stat.average, stat.max)
def eval_model_multithread(model_path): def eval_model_multithread(model_path):
......
...@@ -43,7 +43,10 @@ class AtariPlayer(RLEnvironment): ...@@ -43,7 +43,10 @@ class AtariPlayer(RLEnvironment):
self.ale.setInt("random_seed", self.rng.randint(0, 10000)) self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False) self.ale.setBool("showinfo", False)
#ALEInterface.setLoggerMode(ALEInterface.Logger.Warning) try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError:
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
self.ale.setInt("frame_skip", 1) self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False) self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check # manual.pdf suggests otherwise. may need to check
......
...@@ -42,7 +42,7 @@ class PredictConfig(object): ...@@ -42,7 +42,7 @@ class PredictConfig(object):
input_data_mapping: [0] # the first component in a datapoint should map to `image_var` input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
:param model: a `ModelDesc` instance :param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output variables to predict, the :param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False. :param return_input: whether to produce (input, output) pair or just output. default to False.
......
...@@ -58,7 +58,7 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -58,7 +58,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
describe_model() describe_model()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A predictor worker that takes input and produces output by queue""" """ An offline predictor worker that takes input and produces output by queue"""
def __init__(self, idx, gpuid, inqueue, outqueue, config): def __init__(self, idx, gpuid, inqueue, outqueue, config):
""" """
:param inqueue: input queue to get data point. elements are (task_id, dp) :param inqueue: input queue to get data point. elements are (task_id, dp)
...@@ -100,7 +100,7 @@ class MultiThreadAsyncPredictor(object): ...@@ -100,7 +100,7 @@ class MultiThreadAsyncPredictor(object):
""" """
:param trainer: a `QueueInputTrainer` instance. :param trainer: a `QueueInputTrainer` instance.
""" """
self.input_queue = queue.Queue() self.input_queue = queue.Queue(maxsize=nr_thread*2)
self.threads = [PredictorWorkerThread(self.input_queue, f) self.threads = [PredictorWorkerThread(self.input_queue, f)
for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)] for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)]
......
...@@ -14,6 +14,7 @@ if six.PY2: ...@@ -14,6 +14,7 @@ if six.PY2:
import subprocess32 as subprocess import subprocess32 as subprocess
else: else:
import subprocess import subprocess
from six.moves import queue
from . import logger from . import logger
...@@ -22,16 +23,38 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate', ...@@ -22,16 +23,38 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'start_proc_mask_signal'] 'start_proc_mask_signal']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
"""
A thread that has a 'stop' event.
"""
def __init__(self): def __init__(self):
super(StoppableThread, self).__init__() super(StoppableThread, self).__init__()
self._stop = threading.Event() self._stop = threading.Event()
def stop(self): def stop(self):
""" stop the thread"""
self._stop.set() self._stop.set()
def stopped(self): def stopped(self):
""" check whether the thread is stopped or not"""
return self._stop.isSet() return self._stop.isSet()
def queue_put_stoppable(self, q, obj, timeout=5):
    """ Put obj to queue, but give up if the thread is stopped.

    :param q: the queue to put into; must provide `put(obj, timeout=...)`.
    :param obj: the object to put.
    :param timeout: seconds to block on each attempt before re-checking
        whether the thread has been stopped. default to 5.
    """
    while not self.stopped():
        try:
            q.put(obj, timeout=timeout)
            break
        # `Full` is defined on the queue module, not on the Queue class;
        # it is what a timed-out `put` raises.
        except queue.Full:
            pass
def queue_get_stoppable(self, q, timeout=5):
    """ Take obj from queue, but give up if the thread is stopped.

    :param q: the queue to take from; must provide `get(timeout=...)`.
    :param timeout: seconds to block on each attempt before re-checking
        whether the thread has been stopped. default to 5.
    :returns: the object taken from the queue, or None if the thread was
        stopped before anything could be taken.
    """
    while not self.stopped():
        try:
            return q.get(timeout=timeout)
        # A timed-out `get` raises `Empty` (not `Full`), and the exception
        # is defined on the queue module, not on the Queue class.
        except queue.Empty:
            pass
    return None
class LoopThread(threading.Thread): class LoopThread(threading.Thread):
""" A pausable thread that simply runs a loop""" """ A pausable thread that simply runs a loop"""
def __init__(self, func): def __init__(self, func):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment