Commit 40cee0cb authored by Yuxin Wu's avatar Yuxin Wu

simplify code

parent 6e7eef04
......@@ -63,6 +63,10 @@ def get_player(viz=False, train=False):
live_lost_as_eoe=train)
global NUM_ACTIONS
NUM_ACTIONS = pl.get_num_actions()
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
return pl
class Model(ModelDesc):
......@@ -162,7 +166,7 @@ def play_one_episode(player, func, verbose=False):
return sc
def play_model(model_path):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(0.013), FRAME_HISTORY), 30, 1)
player = get_player(0.013)
cfg = PredictConfig(
model=Model(),
input_data_mapping=[0],
......@@ -180,15 +184,10 @@ def eval_with_funcs(predict_funcs):
self.func = func
self.q = queue
def run(self):
player = PreventStuckPlayer(HistoryFramePlayer(get_player(), FRAME_HISTORY), 30, 1)
player = get_player()
while not self.stopped():
score = play_one_episode(player, self.func)
while not self.stopped():
try:
self.q.put(score, timeout=5)
break
except queue.Queue.Full:
pass
self.queue_put_stoppable(self.q, score)
q = queue.Queue()
threads = [Worker(f, q) for f in predict_funcs]
......@@ -201,9 +200,9 @@ def eval_with_funcs(predict_funcs):
for _ in tqdm(range(EVAL_EPISODE)):
r = q.get()
stat.feed(r)
finally:
for k in threads: k.stop()
for k in threads: k.join()
finally:
return (stat.average, stat.max)
def eval_model_multithread(model_path):
......
......@@ -43,7 +43,10 @@ class AtariPlayer(RLEnvironment):
self.ale.setInt("random_seed", self.rng.randint(0, 10000))
self.ale.setBool("showinfo", False)
#ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
try:
ALEInterface.setLoggerMode(ALEInterface.Logger.Warning)
except AttributeError:
logger.warn("https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!")
self.ale.setInt("frame_skip", 1)
self.ale.setBool('color_averaging', False)
# manual.pdf suggests otherwise. may need to check
......
......@@ -42,7 +42,7 @@ class PredictConfig(object):
input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
:param model: a `ModelDesc` instance
:param output_var_names: a list of names of the output variables to predict, the
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
......
......@@ -58,7 +58,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
describe_model()
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A predictor worker that takes input and produces output by queue"""
""" An offline predictor worker that takes input and produces output by queue"""
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
......@@ -100,7 +100,7 @@ class MultiThreadAsyncPredictor(object):
"""
:param trainer: a `QueueInputTrainer` instance.
"""
self.input_queue = queue.Queue()
self.input_queue = queue.Queue(maxsize=nr_thread*2)
self.threads = [PredictorWorkerThread(self.input_queue, f)
for f in trainer.get_predict_funcs(input_names, output_names, nr_thread)]
......
......@@ -14,6 +14,7 @@ if six.PY2:
import subprocess32 as subprocess
else:
import subprocess
from six.moves import queue
from . import logger
......@@ -22,16 +23,38 @@ __all__ = ['StoppableThread', 'LoopThread', 'ensure_proc_terminate',
'start_proc_mask_signal']
class StoppableThread(threading.Thread):
"""
A thread that has a 'stop' event.
"""
def __init__(self):
super(StoppableThread, self).__init__()
self._stop = threading.Event()
def stop(self):
""" stop the thread"""
self._stop.set()
def stopped(self):
""" check whether the thread is stopped or not"""
return self._stop.isSet()
def queue_put_stoppable(self, q, obj):
""" put obj to queue, but will give up if the thread is stopped"""
while not self.stopped():
try:
q.put(obj, timeout=5)
break
except queue.Queue.Full:
pass
def queue_get_stoppable(self, q):
""" take obj from queue, but will give up if the thread is stopped"""
while not self.stopped():
try:
return q.get(timeout=5)
except queue.Queue.Full:
pass
class LoopThread(threading.Thread):
""" A pausable thread that simply runs a loop"""
def __init__(self, func):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment