Commit 61669a2e authored by Yuxin Wu

misc update

parent ce57a145
-# To run pretrained model:
+# To run pretrained atari model for 100 episodes:
 1. install [tensorpack](https://github.com/ppwwyyxx/tensorpack)
 2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
 3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
-Models are available for the following gym environments:
+Models are available for the following gym atari environments:
 + [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
 + [AirRaid-v0](https://gym.openai.com/envs/AirRaid-v0)
++ [Asterix-v0](https://gym.openai.com/envs/Asterix-v0)
++ [Amidar-v0](https://gym.openai.com/envs/Amidar-v0)
++ [Seaquest-v0](https://gym.openai.com/envs/Seaquest-v0)
 Note that the atari game settings in gym are more difficult than the settings in DeepMind papers, therefore the scores are not comparable.
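For example, to run the Breakout model after downloading `Breakout-v0.tfmodel` from the model zoo into the working directory: `ENV=Breakout-v0 ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`.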
@@ -22,18 +22,14 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 NUM_ACTIONS = None
 ENV_NAME = None
 
-def get_player(viz=False, train=False, dumpdir=None):
+def get_player(dumpdir=None):
     pl = GymEnv(ENV_NAME, dumpdir=dumpdir)
-    def func(img):
-        return cv2.resize(img, IMAGE_SIZE[::-1])
-    pl = MapPlayerState(pl, func)
+    pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
 
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
 
     pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-    if not train:
-        pl = PreventStuckPlayer(pl, 30, 1)
     return pl
 
 class MySimulatorWorker(SimulatorProcess):
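A side note on the new lambda in `get_player`: `cv2.resize` takes its target size as `(width, height)`, while image arrays are indexed `(height, width)`, which is why `IMAGE_SIZE` is reversed with `[::-1]`. A minimal sketch, assuming an `IMAGE_SIZE` of `(84, 84)` (the concrete value is a guess) and a made-up input frame:

```python
import cv2
import numpy as np

IMAGE_SIZE = (84, 84)  # assumed (height, width) value, for illustration only

frame = np.zeros((210, 160, 3), dtype=np.uint8)  # raw Atari-sized RGB frame
# cv2.resize expects its size argument as (width, height), hence the [::-1]:
resized = cv2.resize(frame, IMAGE_SIZE[::-1])
assert resized.shape[:2] == IMAGE_SIZE
```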
@@ -69,10 +65,6 @@ class Model(ModelDesc):
         policy = self._get_NN_prediction(state, is_training)
         self.logits = tf.nn.softmax(policy, name='logits')
 
-    def get_gradient_processor(self):
-        return [MapGradient(lambda grad: tf.clip_by_average_norm(grad, 0.1)),
-                SummaryGradient()]
-
 def play_one_episode(player, func, verbose=False):
     def f(s):
         spc = player.get_action_space()
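The removed `get_gradient_processor` clipped each gradient by its average norm (and added gradient summaries) before the optimizer applied it. For reference, a standalone sketch of the clipping op it used, under the TF1-era API this code targets (the tensor value is made up):

```python
import tensorflow as tf  # TF1-era API, matching the surrounding code

grad = tf.constant([[3.0, 4.0]])
# clip_by_average_norm rescales the tensor so that its L2 norm divided by
# the number of elements stays within the given bound (0.1 in the old code):
clipped = tf.clip_by_average_norm(grad, 0.1)
with tf.Session() as sess:
    print(sess.run(clipped))
```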
@@ -109,5 +101,5 @@ if __name__ == '__main__':
         model=Model(),
         session_init=SaverRestore(args.load),
         input_var_names=['state'],
-        output_var_names=['logits:0'])
+        output_var_names=['logits'])
     run_submission(cfg)
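The last change drops the `:0` suffix: in TensorFlow, `logits` names the op created by `tf.nn.softmax(..., name='logits')`, while `logits:0` names that op's first output tensor, and the config presumably now accepts bare op names. A minimal TF1 sketch of the distinction (the placeholder shape is invented):

```python
import tensorflow as tf  # TF1-style graph API

state = tf.placeholder(tf.float32, [None, 4], name='state')
policy = tf.nn.softmax(state, name='logits')

g = tf.get_default_graph()
op = g.get_operation_by_name('logits')     # the op: bare name
tensor = g.get_tensor_by_name('logits:0')  # its first output: name + ':0'
assert op.outputs[0] is tensor
```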
@@ -107,9 +107,9 @@ class MinSaver(Callback):
                 "Cannot find a checkpoint state. Did you forget to use ModelSaver?")
         path = chpt.model_checkpoint_path
         newname = os.path.join(logger.LOG_DIR,
-                'max_' if self.reverse else 'min_' + self.monitor_stat)
+                ('max-' if self.reverse else 'min-') + self.monitor_stat)
         shutil.copy(path, newname)
-        logger.info("Model with {} {} saved.".format(
+        logger.info("Model with {} '{}' saved.".format(
             'maximum' if self.reverse else 'minimum', self.monitor_stat))
 
 class MaxSaver(MinSaver):
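The parentheses around the conditional in the `newname` line are an editorial fix on top of the commit (which only changed `_` to `-`): Python's conditional expression binds looser than `+`, so without them the stat name is concatenated only in the `else` branch:

```python
stat = 'validation_error'
# + binds tighter than the conditional expression, so the True branch
# silently drops the stat name:
bad = 'max-' if True else 'min-' + stat      # -> 'max-'
good = ('max-' if True else 'min-') + stat   # -> 'max-validation_error'
print(bad, good)
```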