Commit f74ba9a1 authored by Yuxin Wu's avatar Yuxin Wu

fix some atari settings

parent 162f2db0
...@@ -62,8 +62,7 @@ class AtariPlayer(RLEnvironment): ...@@ -62,8 +62,7 @@ class AtariPlayer(RLEnvironment):
with _ALE_LOCK: with _ALE_LOCK:
self.ale = ALEInterface() self.ale = ALEInterface()
self.rng = get_rng(self) self.rng = get_rng(self)
self.ale.setInt(b"random_seed", self.rng.randint(0, 30000))
self.ale.setInt(b"random_seed", self.rng.randint(0, 10000))
self.ale.setBool(b"showinfo", False) self.ale.setBool(b"showinfo", False)
self.ale.setInt(b"frame_skip", 1) self.ale.setInt(b"frame_skip", 1)
...@@ -132,7 +131,8 @@ class AtariPlayer(RLEnvironment): ...@@ -132,7 +131,8 @@ class AtariPlayer(RLEnvironment):
def restart_episode(self): def restart_episode(self):
self.current_episode_score.reset() self.current_episode_score.reset()
self.ale.reset_game() with _ALE_LOCK:
self.ale.reset_game()
# random null-ops start # random null-ops start
n = self.rng.randint(self.nullop_start) n = self.rng.randint(self.nullop_start)
...@@ -160,11 +160,12 @@ class AtariPlayer(RLEnvironment): ...@@ -160,11 +160,12 @@ class AtariPlayer(RLEnvironment):
self.current_episode_score.feed(r) self.current_episode_score.feed(r)
isOver = self.ale.game_over() isOver = self.ale.game_over()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
if isOver: if isOver:
self.finish_episode() self.finish_episode()
if self.ale.game_over():
self.restart_episode() self.restart_episode()
if self.live_lost_as_eoe:
isOver = isOver or newlives < oldlives
return (r, isOver) return (r, isOver)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -48,10 +48,11 @@ def eval_with_funcs(predict_funcs, nr_eval): ...@@ -48,10 +48,11 @@ def eval_with_funcs(predict_funcs, nr_eval):
return self._func(*args, **kwargs) return self._func(*args, **kwargs)
def run(self): def run(self):
player = get_player() player = get_player(train=False)
while not self.stopped(): while not self.stopped():
try: try:
score = play_one_episode(player, self.func) score = play_one_episode(player, self.func)
#print "Score, ", score
except RuntimeError: except RuntimeError:
return return
self.queue_put_stoppable(self.q, score) self.queue_put_stoppable(self.q, score)
......
...@@ -14,6 +14,7 @@ import numpy as np ...@@ -14,6 +14,7 @@ import numpy as np
import six import six
from six.moves import queue from six.moves import queue
from ..models._common import disable_layer_logging
from ..callbacks import Callback from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor from ..predict import OfflinePredictor
...@@ -221,6 +222,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF): ...@@ -221,6 +222,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
self.pred_config = pred_config self.pred_config = pred_config
def _prepare(self): def _prepare(self):
disable_layer_logging()
self.predictor = OfflinePredictor(self.pred_config) self.predictor = OfflinePredictor(self.pred_config)
with self.predictor.graph.as_default(): with self.predictor.graph.as_default():
vars_to_update = self._params_to_update() vars_to_update = self._params_to_update()
...@@ -244,6 +246,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF): ...@@ -244,6 +246,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
def _trigger_evt(self): def _trigger_evt(self):
with self.weight_lock: with self.weight_lock:
self.sess_updater.update(self.shared_dic['params']) self.sess_updater.update(self.shared_dic['params'])
logger.info("Updated.")
def _params_to_update(self): def _params_to_update(self):
# can be overwritten to update more params # can be overwritten to update more params
...@@ -262,7 +265,12 @@ class WeightSync(Callback): ...@@ -262,7 +265,12 @@ class WeightSync(Callback):
# can be overwritten to update more params # can be overwritten to update more params
return tf.trainable_variables() return tf.trainable_variables()
def _before_train(self):
self._sync()
def _trigger_epoch(self): def _trigger_epoch(self):
self._sync()
def _sync(self):
logger.info("Updating weights ...") logger.info("Updating weights ...")
dic = {v.name: v.eval() for v in self.vars} dic = {v.name: v.eval() for v in self.vars}
self.shared_dic['params'] = dic self.shared_dic['params'] = dic
......
...@@ -117,6 +117,9 @@ class EnqueueThread(threading.Thread): ...@@ -117,6 +117,9 @@ class EnqueueThread(threading.Thread):
try: try:
while True: while True:
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
#import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config())
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.input_vars, dp)) feed = dict(zip(self.input_vars, dp))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment