Commit 61669a2e authored by Yuxin Wu

misc update

parent ce57a145
-# To run pretrained model:
+# To run a pretrained atari model for 100 episodes:
 1. Install [tensorpack](https://github.com/ppwwyyxx/tensorpack)
 2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
 3. `ENV=NAME_OF_ENV; ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
-Models are available for the following gym environments:
+Models are available for the following gym atari environments:
 + [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
 + [AirRaid-v0](https://gym.openai.com/envs/AirRaid-v0)
 + [Asterix-v0](https://gym.openai.com/envs/Asterix-v0)
 + [Amidar-v0](https://gym.openai.com/envs/Amidar-v0)
 + [Seaquest-v0](https://gym.openai.com/envs/Seaquest-v0)
 Note that the atari game settings in gym are harder than those in the DeepMind papers, so the scores are not directly comparable.
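For example, to evaluate the Breakout model (assuming `Breakout-v0.tfmodel` was downloaded into the current directory): `ENV=Breakout-v0; ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`.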
@@ -22,18 +22,14 @@ IMAGE_SHAPE3 = IMAGE_SIZE + (CHANNEL,)
 NUM_ACTIONS = None
 ENV_NAME = None
 
-def get_player(viz=False, train=False, dumpdir=None):
+def get_player(dumpdir=None):
     pl = GymEnv(ENV_NAME, dumpdir=dumpdir)
-    def func(img):
-        return cv2.resize(img, IMAGE_SIZE[::-1])
-    pl = MapPlayerState(pl, func)
+    pl = MapPlayerState(pl, lambda img: cv2.resize(img, IMAGE_SIZE[::-1]))
 
     global NUM_ACTIONS
     NUM_ACTIONS = pl.get_action_space().num_actions()
 
     pl = HistoryFramePlayer(pl, FRAME_HISTORY)
-    if not train:
-        pl = PreventStuckPlayer(pl, 30, 1)
     return pl
 
 class MySimulatorWorker(SimulatorProcess):
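For readers unfamiliar with tensorpack's player wrappers: `get_player` builds a decorator chain in which each wrapper exposes the same interface and transforms the states of the player it wraps. Below is a minimal self-contained sketch of that pattern; the classes are simplified stand-ins written for illustration, not tensorpack's actual implementations.

```python
from collections import deque

import numpy as np

class RandomPlayer:
    """Stand-in environment that emits random 210x160x3 'frames'."""
    def current_state(self):
        return np.random.randint(0, 255, size=(210, 160, 3), dtype=np.uint8)

class MapPlayerState:
    """Applies a function to every state, like the cv2.resize call above."""
    def __init__(self, player, func):
        self.player, self.func = player, func
    def current_state(self):
        return self.func(self.player.current_state())

class HistoryFramePlayer:
    """Stacks the last k states along the channel axis."""
    def __init__(self, player, history):
        self.player = player
        self.frames = deque(maxlen=history)
    def current_state(self):
        self.frames.append(self.player.current_state())
        while len(self.frames) < self.frames.maxlen:  # pad at episode start
            self.frames.append(self.frames[-1])
        return np.concatenate(self.frames, axis=-1)

pl = RandomPlayer()
pl = MapPlayerState(pl, lambda img: img[::2, ::2])  # cheap stand-in for resize
pl = HistoryFramePlayer(pl, 4)
print(pl.current_state().shape)  # -> (105, 80, 12)
```

Because every wrapper keeps the same interface, wrappers can be stacked in any order and the model only ever sees the final, fully preprocessed state.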
@@ -69,10 +65,6 @@ class Model(ModelDesc):
         policy = self._get_NN_prediction(state, is_training)
         self.logits = tf.nn.softmax(policy, name='logits')
 
-    def get_gradient_processor(self):
-        return [MapGradient(lambda grad: tf.clip_by_average_norm(grad, 0.1)),
-                SummaryGradient()]
 
 def play_one_episode(player, func, verbose=False):
     def f(s):
         spc = player.get_action_space()
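The deleted `get_gradient_processor` clipped each gradient by its average L2 norm before the optimizer applied it (tensorpack's `MapGradient` runs such a function over every gradient). As a rough standalone sketch of what that clipping does, assuming a plain TF1 optimizer loop; `clip_grads_by_average_norm` is a name made up here for illustration:

```python
import tensorflow as tf  # TF1-style API, matching this codebase

def clip_grads_by_average_norm(grads_and_vars, clip=0.1):
    """Rescale each gradient so its L2 norm / element count is <= clip."""
    # tf.clip_by_average_norm treats each gradient independently and
    # normalizes by its element count; gentler than a global-norm clip.
    return [(tf.clip_by_average_norm(g, clip), v)
            for g, v in grads_and_vars if g is not None]

# Hypothetical usage with a plain optimizer:
#   opt = tf.train.AdamOptimizer(1e-3)
#   train_op = opt.apply_gradients(
#       clip_grads_by_average_norm(opt.compute_gradients(loss)))
```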
@@ -109,5 +101,5 @@ if __name__ == '__main__':
         model=Model(),
         session_init=SaverRestore(args.load),
         input_var_names=['state'],
-        output_var_names=['logits:0'])
+        output_var_names=['logits'])
     run_submission(cfg)
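The `'logits:0'` → `'logits'` change reflects TensorFlow's naming scheme: `tf.nn.softmax(policy, name='logits')` creates an *op* named `logits`, whose first output *tensor* is `logits:0`, and the config apparently now expects op names. A quick standalone illustration on a hypothetical toy graph:

```python
import tensorflow as tf  # TF1 graph mode

x = tf.placeholder(tf.float32, [None, 4], name='state')
y = tf.nn.softmax(x, name='logits')

g = tf.get_default_graph()
op = g.get_operation_by_name('logits')  # op name, no ':0'
t = g.get_tensor_by_name('logits:0')    # tensor name: op name + output index
assert t is op.outputs[0] and t is y
```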
@@ -107,9 +107,9 @@ class MinSaver(Callback):
                 "Cannot find a checkpoint state. Did you forget to use ModelSaver?")
         path = chpt.model_checkpoint_path
         newname = os.path.join(logger.LOG_DIR,
-                'max_' if self.reverse else 'min_' + self.monitor_stat)
+                'max-' if self.reverse else 'min-' + self.monitor_stat)
         shutil.copy(path, newname)
-        logger.info("Model with {} {} saved.".format(
+        logger.info("Model with {} '{}' saved.".format(
             'maximum' if self.reverse else 'minimum', self.monitor_stat))
 
 class MaxSaver(MinSaver):
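One caveat with the renamed prefix line: Python's conditional expression binds more loosely than `+`, so `'max-' if self.reverse else 'min-' + self.monitor_stat` yields just `'max-'` with no stat name whenever `reverse` is true. The prefix choice likely wants parentheses, as this toy example shows:

```python
monitor_stat, reverse = 'validation_error', True

# As written: '+' binds tighter, so the true branch is the bare prefix.
print('max-' if reverse else 'min-' + monitor_stat)    # -> max-

# Intended: pick the prefix first, then append the stat name.
print(('max-' if reverse else 'min-') + monitor_stat)  # -> max-validation_error
```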