Commit 55098813 authored by Yuxin Wu

Upgrade gym

parent 4bc0c748
@@ -362,7 +362,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         # include_init_with_doc doesn't work well for decorated init
         # https://github.com/sphinx-doc/sphinx/issues/4258
         return False
-    # hide deprecated stuff
+    # Hide some names that are deprecated or not intended to be used
     if name in [
         # deprecated stuff:
         'GaussianDeform',
......
This diff is collapsed.
@@ -54,7 +54,7 @@ ENV_NAME = None
 def get_player(train=False, dumpdir=None):
     env = gym.make(ENV_NAME)
     if dumpdir:
-        env = gym.wrappers.Monitor(env, dumpdir)
+        env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
     env = FireResetEnv(env)
     env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
     env = FrameStack(env, 4)
@@ -272,7 +272,7 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     parser.add_argument('--env', help='env', required=True)
     parser.add_argument('--task', help='task to perform',
-                        choices=['play', 'eval', 'train', 'gen_submit'], default='train')
+                        choices=['play', 'eval', 'train', 'dump_video'], default='train')
     parser.add_argument('--output', help='output directory for submission', default='output_dir')
     parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
     args = parser.parse_args()
@@ -297,10 +297,9 @@ if __name__ == '__main__':
             args.episode, render=True)
     elif args.task == 'eval':
         eval_model_multithread(pred, args.episode, get_player)
-    elif args.task == 'gen_submit':
+    elif args.task == 'dump_video':
         play_n_episodes(
             get_player(train=False, dumpdir=args.output),
             pred, args.episode)
-        # gym.upload(args.output, api_key='xxx')
     else:
         train()
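Note on the hunk above: the old gen_submit task produced recordings for the now-retired OpenAI Gym upload service, which is what the deleted gym.upload(...) comment pointed at. It is renamed to dump_video, and get_player now passes video_callable=lambda _: True to gym.wrappers.Monitor so that every evaluated episode is recorded instead of gym's sparse default schedule. A minimal sketch of the same recording setup outside this script; the environment id and the random policy are illustrative only:

import gym

# Record every episode to ./video_dir. Without video_callable, Monitor only
# records a capped subset of episodes by default.
env = gym.make('PongNoFrameskip-v4')    # illustrative env id
env = gym.wrappers.Monitor(env, './video_dir',
                           video_callable=lambda episode_id: True)

for _ in range(2):                      # dump two episodes
    ob, done = env.reset(), False
    while not done:
        ob, reward, done, info = env.step(env.action_space.sample())
env.close()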
@@ -9,7 +9,7 @@ import tensorpack
 from tensorpack import ModelDesc, InputDesc
 from tensorpack.utils import logger
 from tensorpack.tfutils import (
-    summary, get_current_tower_context, optimizer, gradproc)
+    varreplace, summary, get_current_tower_context, optimizer, gradproc)
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

 assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
@@ -60,7 +60,7 @@ class Model(ModelDesc):
             self.predict_value, 1), name='predict_reward')
         summary.add_moving_summary(max_pred_reward)

-        with tf.variable_scope('target'):
+        with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
             targetQ_predict_value = self.get_DQN_prediction(next_state)    # NxA

         if self.method != 'Double':
@@ -96,6 +96,6 @@ class Model(ModelDesc):
             target_name = v.op.name
             if target_name.startswith('target'):
                 new_name = target_name.replace('target/', '')
-                logger.info("{} <- {}".format(target_name, new_name))
+                logger.info("Target Network Update: {} <- {}".format(target_name, new_name))
                 ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
         return tf.group(*ops, name='update_target_network')
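The two changes above work together: wrapping the target scope in varreplace.freeze_variables(skip_collection=True) keeps the target network's variables out of the trainable-variable collection, so the optimizer never updates them; they only change when update_target_network copies the online weights over them. A rough plain-TensorFlow 1.x sketch of that pattern; scope and layer names here are illustrative, not the ones DQNModel actually uses:

import tensorflow as tf

def q_net(state, trainable):
    # Tiny stand-in for the real Q-network (get_DQN_prediction in the model).
    return tf.layers.dense(state, 4, trainable=trainable, name='fc')

state = tf.placeholder(tf.float32, [None, 8], name='state')
next_state = tf.placeholder(tf.float32, [None, 8], name='next_state')

with tf.variable_scope('online'):
    q_values = q_net(state, trainable=True)
with tf.variable_scope('target'):
    # Frozen copy: not trainable, so the optimizer ignores it -- analogous to
    # what freeze_variables(skip_collection=True) does in tensorpack.
    target_q_values = q_net(next_state, trainable=False)

# Periodic hard update: copy every online variable onto its target twin.
ops = []
for v in tf.global_variables():
    if v.op.name.startswith('target/'):
        src_name = v.op.name.replace('target/', 'online/')
        src = tf.get_default_graph().get_tensor_by_name(src_name + ':0')
        ops.append(v.assign(src))
update_target = tf.group(*ops, name='update_target_network')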
@@ -138,12 +138,12 @@ class AtariPlayer(gym.Env):
            self.last_raw_screen = self._grab_raw_image()
            self.ale.act(0)

-    def _reset(self):
+    def reset(self):
        if self.ale.game_over():
            self._restart_episode()
        return self._current_state()

-    def _step(self, act):
+    def step(self, act):
        oldlives = self.ale.lives()
        r = 0
        for k in range(self.frame_skip):
......
@@ -8,6 +8,9 @@ from collections import deque
 import gym
 from gym import spaces

+_v0, _v1 = gym.__version__.split('.')[:2]
+assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__

 """
 The following wrappers are copied or modified from openai/baselines:
@@ -20,7 +23,7 @@ class MapState(gym.ObservationWrapper):
         gym.ObservationWrapper.__init__(self, env)
         self._func = map_func

-    def _observation(self, obs):
+    def observation(self, obs):
         return self._func(obs)
@@ -32,22 +35,23 @@ class FrameStack(gym.Wrapper):
         self.frames = deque([], maxlen=k)
         shp = env.observation_space.shape
         chan = 1 if len(shp) == 2 else shp[2]
-        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)

-    def _reset(self):
+    def reset(self):
         """Clear buffer and re-fill by duplicating the first observation."""
         ob = self.env.reset()
         for _ in range(self.k - 1):
             self.frames.append(np.zeros_like(ob))
         self.frames.append(ob)
-        return self._observation()
+        return self.observation()

-    def _step(self, action):
+    def step(self, action):
         ob, reward, done, info = self.env.step(action)
         self.frames.append(ob)
-        return self._observation(), reward, done, info
+        return self.observation(), reward, done, info

-    def _observation(self):
+    def observation(self):
         assert len(self.frames) == self.k
         if self.frames[-1].ndim == 2:
             return np.stack(self.frames, axis=-1)
@@ -62,7 +66,7 @@ class _FireResetEnv(gym.Wrapper):
         assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
         assert len(env.unwrapped.get_action_meanings()) >= 3

-    def _reset(self):
+    def reset(self):
         self.env.reset()
         obs, _, done, _ = self.env.step(1)
         if done:
@@ -72,6 +76,9 @@ class _FireResetEnv(gym.Wrapper):
             self.env.reset()
         return obs

+    def step(self, action):
+        return self.env.step(action)

 def FireResetEnv(env):
     if isinstance(env, gym.Wrapper):
@@ -88,7 +95,7 @@ class LimitLength(gym.Wrapper):
         gym.Wrapper.__init__(self, env)
         self.k = k

-    def _reset(self):
+    def reset(self):
         # This assumes that reset() will really reset the env.
         # If the underlying env tries to be smart about reset
         # (e.g. end-of-life), the assumption doesn't hold.
@@ -96,7 +103,7 @@ class LimitLength(gym.Wrapper):
         self.cnt = 0
         return ob

-    def _step(self, action):
+    def step(self, action):
         ob, r, done, info = self.env.step(action)
         self.cnt += 1
         if self.cnt == self.k:
......
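Most of the wrapper edits above track the gym 0.9 → 0.10 API: gym.Wrapper no longer dispatches the underscored _reset/_step/_observation hooks, so subclasses must override the public reset, step and observation methods, and spaces.Box now wants an explicit dtype (hence the version assertion added at the top of the file). A minimal new-style observation wrapper as a sketch; the GrayScale name and the environment id are made up for illustration:

import gym
import numpy as np
from gym import spaces

class GrayScale(gym.ObservationWrapper):
    """Average the color channels, written against the gym >= 0.10 API."""
    def __init__(self, env):
        super(GrayScale, self).__init__(env)
        h, w, _ = env.observation_space.shape
        # gym >= 0.10 expects an explicit dtype on Box spaces.
        self.observation_space = spaces.Box(low=0, high=255, shape=(h, w), dtype=np.uint8)

    def observation(self, obs):    # public name; '_observation' is no longer called
        return obs.mean(axis=2).astype(np.uint8)

if __name__ == '__main__':
    env = GrayScale(gym.make('Breakout-v0'))    # illustrative env id
    ob = env.reset()
    ob, reward, done, info = env.step(env.action_space.sample())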
@@ -18,7 +18,7 @@ from tensorpack.utils.utils import get_tqdm_kwargs
 def play_one_episode(env, func, render=False):
     def predict(s):
         """
-        Map from observation to action, with 0.001 greedy.
+        Map from observation to action, with 0.01 greedy.
         """
         act = func(s[None, :, :, :])[0][0].argmax()
         if random.random() < 0.01:
@@ -45,7 +45,7 @@ def play_n_episodes(player, predfunc, nr, render=False):
         print("{}/{}, score={}".format(k, nr, score))

-def eval_with_funcs(predictors, nr_eval, get_player_fn):
+def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
     """
     Args:
         predictors ([PredictorBase])
@@ -67,7 +67,6 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
             while not self.stopped():
                 try:
                     score = play_one_episode(player, self.func)
-                    # print("Score, ", score)
                 except RuntimeError:
                     return
                 self.queue_put_stoppable(self.q, score)
@@ -80,17 +79,21 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
         time.sleep(0.1)  # avoid simulator bugs
     stat = StatCounter()

-    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+    def fetch():
         r = q.get()
         stat.feed(r)
+        if verbose:
+            logger.info("Score: {}".format(r))
+
+    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+        fetch()
     logger.info("Waiting for all the workers to finish the last run...")
     for k in threads:
         k.stop()
     for k in threads:
         k.join()
     while q.qsize():
-        r = q.get()
-        stat.feed(r)
+        fetch()
     if stat.count > 0:
         return (stat.average, stat.max)
@@ -100,11 +103,13 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
 def eval_model_multithread(pred, nr_eval, get_player_fn):
     """
     Args:
-        pred (OfflinePredictor): state -> Qvalue
+        pred (OfflinePredictor): state -> [#action]
     """
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     with pred.sess.as_default():
-        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
+        mean, max = eval_with_funcs(
+            [pred] * NR_PROC, nr_eval,
+            get_player_fn, verbose=True)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
......
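The docstring fix in the first hunk of this file ("0.001 greedy" → "0.01 greedy") just brings the comment in line with the code: evaluation picks the argmax of the predicted Q-values but substitutes a uniformly random action with probability 0.01, so a deterministic policy cannot get permanently stuck. A self-contained sketch of that epsilon-greedy rule; the q_values array below is made up for illustration:

import random
import numpy as np

def epsilon_greedy(q_values, epsilon=0.01):
    """Return argmax(q_values), or a random action with probability epsilon."""
    act = int(np.argmax(q_values))
    if random.random() < epsilon:
        act = random.randrange(len(q_values))
    return act

# Example: 4 actions, the third has the highest predicted value.
print(epsilon_greedy(np.array([0.1, 0.5, 2.0, -1.0])))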
@@ -95,6 +95,11 @@ class ModelDescBase(object):
         Args:
             args ([tf.Tensor]): tensors that matches the list of
                 :class:`InputDesc` defined by ``_get_inputs``.

+        Returns:
+            In general it returns nothing, but a subclass (e.g.
+            :class:`ModelDesc`) may require it to return necessary information
+            to build the trainer.
         """
         if len(args) == 1:
             arg = args[0]
@@ -124,18 +129,16 @@ class ModelDesc(ModelDescBase):
 class ModelDesc(ModelDescBase):
     """
-    A ModelDesc with single cost and single optimizer.
+    A ModelDesc with **single cost** and **single optimizer**.
     It contains information about InputDesc, how to get cost, and how to get optimizer.
     """

     def get_cost(self):
         """
-        Return the cost tensor in the graph.
-        It calls :meth:`ModelDesc._get_cost()` which by default returns
-        ``self.cost``. You can override :meth:`_get_cost()` if needed.
-
-        This function also applies the collection
+        Return the cost tensor to optimize on.
+
+        This function takes the cost tensor defined by :meth:`build_graph`,
+        and applies the collection
         ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
         """
         cost = self._get_cost()
@@ -165,6 +168,9 @@ class ModelDesc(ModelDescBase):
         raise NotImplementedError()

     def _build_graph_get_cost(self, *inputs):
+        """
+        Used by trainers to get the final cost for optimization.
+        """
         self.build_graph(*inputs)
         return self.get_cost()
......
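The reworded get_cost docstring describes unchanged behaviour: the cost produced by the model is combined with whatever has accumulated in tf.GraphKeys.REGULARIZATION_LOSSES. Conceptually, and outside tensorpack's own helpers, that amounts to the following plain TensorFlow 1.x sketch; layer sizes and the regularizer weight are illustrative:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 16], name='input')
label = tf.placeholder(tf.int32, [None], name='label')

# A kernel_regularizer registers its penalty under REGULARIZATION_LOSSES.
logits = tf.layers.dense(
    x, 10, kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))
cost = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label))

# Roughly what ModelDesc.get_cost() does: fold the collected regularization
# losses into the cost that the optimizer will minimize.
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_cost = tf.add_n([cost] + reg_losses, name='total_cost') if reg_losses else cost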