Commit f74ba9a1 authored by Yuxin Wu

fix some atari settings

parent 162f2db0
......@@ -62,8 +62,7 @@ class AtariPlayer(RLEnvironment):
with _ALE_LOCK:
self.ale = ALEInterface()
self.rng = get_rng(self)
self.ale.setInt(b"random_seed", self.rng.randint(0, 10000))
self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
self.ale.setBool(b"showinfo", False)
self.ale.setInt(b"frame_skip", 1)
......@@ -132,7 +131,8 @@ class AtariPlayer(RLEnvironment):
def restart_episode(self):
self.current_episode_score.reset()
self.ale.reset_game()
with _ALE_LOCK:
self.ale.reset_game()
# random null-ops start
n = self.rng.randint(self.nullop_start)
......@@ -160,11 +160,12 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r)
isOver = self.ale.game_over()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
if isOver:
self.finish_episode()
if self.ale.game_over():
self.restart_episode()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
return (r, isOver)
if __name__ == '__main__':
......
......@@ -48,10 +48,11 @@ def eval_with_funcs(predict_funcs, nr_eval):
return self._func(*args, **kwargs)
def run(self):
player = get_player()
player = get_player(train=False)
while not self.stopped():
try:
score = play_one_episode(player, self.func)
#print "Score, ", score
except RuntimeError:
return
self.queue_put_stoppable(self.q, score)
......
......@@ -14,6 +14,7 @@ import numpy as np
import six
from six.moves import queue
from ..models._common import disable_layer_logging
from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
......@@ -221,6 +222,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
self.pred_config = pred_config
def _prepare(self):
disable_layer_logging()
self.predictor = OfflinePredictor(self.pred_config)
with self.predictor.graph.as_default():
vars_to_update = self._params_to_update()
......@@ -244,6 +246,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
def _trigger_evt(self):
with self.weight_lock:
self.sess_updater.update(self.shared_dic['params'])
logger.info("Updated.")
def _params_to_update(self):
# can be overwritten to update more params
......@@ -262,7 +265,12 @@ class WeightSync(Callback):
# can be overwritten to update more params
return tf.trainable_variables()
def _before_train(self):
self._sync()
def _trigger_epoch(self):
self._sync()
def _sync(self):
logger.info("Updating weights ...")
dic = {v.name: v.eval() for v in self.vars}
self.shared_dic['params'] = dic
......
......@@ -117,6 +117,9 @@ class EnqueueThread(threading.Thread):
try:
while True:
for dp in self.dataflow.get_data():
#import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config())
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment