Commit 55098813 authored by Yuxin Wu

Upgrade gym

parent 4bc0c748
......@@ -362,7 +362,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
# include_init_with_doc doesn't work well for decorated init
# https://github.com/sphinx-doc/sphinx/issues/4258
return False
# hide deprecated stuff
# Hide some names that are deprecated or not intended to be used
if name in [
# deprecated stuff:
'GaussianDeform',
......
This diff is collapsed.
......@@ -54,7 +54,7 @@ ENV_NAME = None
def get_player(train=False, dumpdir=None):
env = gym.make(ENV_NAME)
if dumpdir:
env = gym.wrappers.Monitor(env, dumpdir)
env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
env = FrameStack(env, 4)
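
A minimal standalone sketch of the Monitor usage above, assuming old-style gym 0.x and a hypothetical env id: passing video_callable=lambda _: True asks Monitor to record every episode instead of gym's default capped recording schedule.

import gym

env = gym.make("Breakout-v0")                  # hypothetical env id
env = gym.wrappers.Monitor(
    env, "/tmp/gym-videos",                    # dump directory
    video_callable=lambda _: True,             # record every episode
    force=True)                                # overwrite old monitor files
ob = env.reset()
done = False
while not done:
    ob, reward, done, info = env.step(env.action_space.sample())
env.close()
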
......@@ -272,7 +272,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
parser.add_argument('--env', help='env', required=True)
parser.add_argument('--task', help='task to perform',
choices=['play', 'eval', 'train', 'gen_submit'], default='train')
choices=['play', 'eval', 'train', 'dump_video'], default='train')
parser.add_argument('--output', help='output directory for submission', default='output_dir')
parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
args = parser.parse_args()
......@@ -297,10 +297,9 @@ if __name__ == '__main__':
args.episode, render=True)
elif args.task == 'eval':
eval_model_multithread(pred, args.episode, get_player)
elif args.task == 'gen_submit':
elif args.task == 'dump_video':
play_n_episodes(
get_player(train=False, dumpdir=args.output),
pred, args.episode)
# gym.upload(args.output, api_key='xxx')
else:
train()
......@@ -9,7 +9,7 @@ import tensorpack
from tensorpack import ModelDesc, InputDesc
from tensorpack.utils import logger
from tensorpack.tfutils import (
summary, get_current_tower_context, optimizer, gradproc)
varreplace, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2
......@@ -60,7 +60,7 @@ class Model(ModelDesc):
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'):
with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
targetQ_predict_value = self.get_DQN_prediction(next_state) # NxA
if self.method != 'Double':
......@@ -96,6 +96,6 @@ class Model(ModelDesc):
target_name = v.op.name
if target_name.startswith('target'):
new_name = target_name.replace('target/', '')
logger.info("{} <- {}".format(target_name, new_name))
logger.info("Target Network Update: {} <- {}".format(target_name, new_name))
ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
return tf.group(*ops, name='update_target_network')
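
A minimal sketch of what freeze_variables(skip_collection=True) does for the target network, assuming a hypothetical one-layer Q head: variables created inside the context are kept out of TRAINABLE_VARIABLES, so the optimizer never updates them and the target network only changes through the assign ops built above.

import tensorflow as tf
from tensorpack.tfutils import varreplace

def q_head(state):
    # hypothetical tiny Q head, just for illustration
    return tf.layers.dense(state, 4, name='fc')

state = tf.placeholder(tf.float32, [None, 16])
online_q = q_head(state)                    # trainable as usual
with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
    target_q = q_head(state)                # created, but not added to TRAINABLE_VARIABLES

# Only the online 'fc/...' variables show up here; no 'target/...' entries:
print([v.op.name for v in tf.trainable_variables()])
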
......@@ -138,12 +138,12 @@ class AtariPlayer(gym.Env):
self.last_raw_screen = self._grab_raw_image()
self.ale.act(0)
def _reset(self):
def reset(self):
if self.ale.game_over():
self._restart_episode()
return self._current_state()
def _step(self, act):
def step(self, act):
oldlives = self.ale.lives()
r = 0
for k in range(self.frame_skip):
......
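
For reference, a minimal sketch of the gym >= 0.10 Env interface this file migrates to (the class, shapes and dynamics are made up): the public reset()/step() methods replace the old underscore-prefixed _reset()/_step() hooks.

import gym
import numpy as np
from gym import spaces

class DummyEnv(gym.Env):
    observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8)
    action_space = spaces.Discrete(4)

    def reset(self):
        # return the first observation
        return np.zeros((84, 84, 3), np.uint8)

    def step(self, action):
        # return (observation, reward, done, info)
        return np.zeros((84, 84, 3), np.uint8), 0.0, True, {}
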
......@@ -8,6 +8,9 @@ from collections import deque
import gym
from gym import spaces
_v0, _v1 = gym.__version__.split('.')[:2]
assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
"""
The following wrappers are copied or modified from openai/baselines:
......@@ -20,7 +23,7 @@ class MapState(gym.ObservationWrapper):
gym.ObservationWrapper.__init__(self, env)
self._func = map_func
def _observation(self, obs):
def observation(self, obs):
return self._func(obs)
......@@ -32,22 +35,23 @@ class FrameStack(gym.Wrapper):
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
chan = 1 if len(shp) == 2 else shp[2]
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
self.observation_space = spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)
def _reset(self):
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k - 1):
self.frames.append(np.zeros_like(ob))
self.frames.append(ob)
return self._observation()
return self.observation()
def _step(self, action):
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._observation(), reward, done, info
return self.observation(), reward, done, info
def _observation(self):
def observation(self):
assert len(self.frames) == self.k
if self.frames[-1].ndim == 2:
return np.stack(self.frames, axis=-1)
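
A tiny sanity check of the stacking logic above, with hypothetical 84x84 grayscale frames: four 2-D frames are stacked along a new last axis into one observation.

import numpy as np

frames = [np.zeros((84, 84), np.uint8) for _ in range(4)]
stacked = np.stack(frames, axis=-1)
assert stacked.shape == (84, 84, 4)
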
......@@ -62,7 +66,7 @@ class _FireResetEnv(gym.Wrapper):
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def _reset(self):
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
......@@ -72,6 +76,9 @@ class _FireResetEnv(gym.Wrapper):
self.env.reset()
return obs
def step(self, action):
return self.env.step(action)
def FireResetEnv(env):
if isinstance(env, gym.Wrapper):
......@@ -88,7 +95,7 @@ class LimitLength(gym.Wrapper):
gym.Wrapper.__init__(self, env)
self.k = k
def _reset(self):
def reset(self):
# This assumes that reset() will really reset the env.
# If the underlying env tries to be smart about reset
# (e.g. end-of-life), the assumption doesn't hold.
......@@ -96,7 +103,7 @@ class LimitLength(gym.Wrapper):
self.cnt = 0
return ob
def _step(self, action):
def step(self, action):
ob, r, done, info = self.env.step(action)
self.cnt += 1
if self.cnt == self.k:
......
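
A minimal sketch of a new-style observation wrapper in the spirit of MapState above (the GrayScale class and its transform are made up): gym now dispatches to the public observation() hook instead of the old _observation().

import gym

class GrayScale(gym.ObservationWrapper):
    def observation(self, obs):
        # average the color channels; a real wrapper would also
        # update self.observation_space accordingly
        return obs.mean(axis=-1, keepdims=True).astype(obs.dtype)
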
......@@ -18,7 +18,7 @@ from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(env, func, render=False):
def predict(s):
"""
Map from observation to action, with 0.001 greedy.
Map from observation to action, with 0.01 greedy.
"""
act = func(s[None, :, :, :])[0][0].argmax()
if random.random() < 0.01:
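
The 0.01-greedy rule described by that docstring, as a small self-contained sketch (the helper name is made up):

import random
import numpy as np

def epsilon_greedy(q_values, eps=0.01):
    # pick a uniformly random action with probability eps,
    # otherwise the action with the largest predicted Q-value
    if random.random() < eps:
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))
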
......@@ -45,7 +45,7 @@ def play_n_episodes(player, predfunc, nr, render=False):
print("{}/{}, score={}".format(k, nr, score))
def eval_with_funcs(predictors, nr_eval, get_player_fn):
def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
"""
Args:
predictors ([PredictorBase])
......@@ -67,7 +67,6 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
while not self.stopped():
try:
score = play_one_episode(player, self.func)
# print("Score, ", score)
except RuntimeError:
return
self.queue_put_stoppable(self.q, score)
......@@ -80,17 +79,21 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
time.sleep(0.1) # avoid simulator bugs
stat = StatCounter()
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
def fetch():
r = q.get()
stat.feed(r)
if verbose:
logger.info("Score: {}".format(r))
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
fetch()
logger.info("Waiting for all the workers to finish the last run...")
for k in threads:
k.stop()
for k in threads:
k.join()
while q.qsize():
r = q.get()
stat.feed(r)
fetch()
if stat.count > 0:
return (stat.average, stat.max)
......@@ -100,11 +103,13 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
def eval_model_multithread(pred, nr_eval, get_player_fn):
"""
Args:
pred (OfflinePredictor): state -> Qvalue
pred (OfflinePredictor): state -> [#action]
"""
NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
with pred.sess.as_default():
mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
mean, max = eval_with_funcs(
[pred] * NR_PROC, nr_eval,
get_player_fn, verbose=True)
logger.info("Average Score: {}; Max Score: {}".format(mean, max))
......
......@@ -95,6 +95,11 @@ class ModelDescBase(object):
Args:
args ([tf.Tensor]): tensors that match the list of
:class:`InputDesc` defined by ``_get_inputs``.
Returns:
In general it returns nothing, but a subclass (e.g.
:class:`ModelDesc`) may require it to return necessary information
to build the trainer.
"""
if len(args) == 1:
arg = args[0]
......@@ -124,18 +129,16 @@ class ModelDescBase(object):
class ModelDesc(ModelDescBase):
"""
A ModelDesc with single cost and single optimizer.
A ModelDesc with **single cost** and **single optimizer**.
It contains information about InputDesc, how to get cost, and how to get optimizer.
"""
def get_cost(self):
"""
Return the cost tensor in the graph.
It calls :meth:`ModelDesc._get_cost()` which by default returns
``self.cost``. You can override :meth:`_get_cost()` if needed.
Return the cost tensor to optimize on.
This function also applies the collection
This function takes the cost tensor defined by :meth:`build_graph`,
and automatically adds the regularization losses collected in
``tf.GraphKeys.REGULARIZATION_LOSSES`` to it.
"""
cost = self._get_cost()
......@@ -165,6 +168,9 @@ class ModelDesc(ModelDescBase):
raise NotImplementedError()
def _build_graph_get_cost(self, *inputs):
"""
Used by trainers to get the final cost for optimization.
"""
self.build_graph(*inputs)
return self.get_cost()
......
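
A minimal sketch of a single-cost ModelDesc following the contract documented above, assuming the _get_inputs / _build_graph / _get_optimizer overrides used by tensorpack examples of this era (hook names vary slightly across versions; the model, shapes and names are made up):

import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MyModel(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, 28, 28, 1), 'image'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
        # get_cost() later returns this tensor, with REGULARIZATION_LOSSES added on top
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits),
            name='cost')

    def _get_optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
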