Commit 542e00f2 authored by Yuxin Wu's avatar Yuxin Wu

small fix

parent e28d616e
# Steps to reproduce: # To run pretrained model:
1. install [tensorpack](https://github.com/ppwwyyxx/tensorpack) 1. install [tensorpack](https://github.com/ppwwyyxx/tensorpack)
2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk) 2. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"` 3. `ENV=NAME_OF_ENV ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
<!--
-Models are available for the following gym environments:
-
-+ [Breakout-v0](https://gym.openai.com/envs/Breakout-v0)
-->
Note that atari game settings in gym is very different from DeepMind papers, therefore the scores are not comparable.
...@@ -34,7 +34,6 @@ def get_player(viz=False, train=False, dumpdir=None): ...@@ -34,7 +34,6 @@ def get_player(viz=False, train=False, dumpdir=None):
pl = HistoryFramePlayer(pl, FRAME_HISTORY) pl = HistoryFramePlayer(pl, FRAME_HISTORY)
if not train: if not train:
pl = PreventStuckPlayer(pl, 30, 1) pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 40000)
return pl return pl
class MySimulatorWorker(SimulatorProcess): class MySimulatorWorker(SimulatorProcess):
......
...@@ -33,7 +33,7 @@ with tf.Graph().as_default() as G: ...@@ -33,7 +33,7 @@ with tf.Graph().as_default() as G:
init = sessinit.ParamRestore(np.load(args.model).item()) init = sessinit.ParamRestore(np.load(args.model).item())
else: else:
init = sessinit.SaverRestore(args.model) init = sessinit.SaverRestore(args.model)
sess = tf.Session() sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
init.init(sess) init.init(sess)
# dump ... # dump ...
......
...@@ -88,6 +88,12 @@ class DiscreteActionSpace(ActionSpace): ...@@ -88,6 +88,12 @@ class DiscreteActionSpace(ActionSpace):
def num_actions(self): def num_actions(self):
return self.num return self.num
def __repr__(self):
return "DiscreteActionSpace({})".format(self.num)
def __str__(self):
return "DiscreteActionSpace({})".format(self.num)
class NaiveRLEnvironment(RLEnvironment): class NaiveRLEnvironment(RLEnvironment):
""" for testing only""" """ for testing only"""
def __init__(self): def __init__(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment