Commit 52aae61a authored by Yuxin Wu's avatar Yuxin Wu

gymenv. fix gradproc. auto-restart limitlength

parent ee227da4
......@@ -13,12 +13,12 @@ See some interesting [examples](examples) to learn about the framework:
## Features:
Focus on modularity. You just have to define the following three components to start a training:
You need to abstract your training task into three components:
1. The model, or the graph. `models/` has some scoped abstraction of common models.
1. Model, or graph. `models/` has some scoped abstraction of common models.
`LinearWrap` and `argscope` makes large models look simpler.
2. The data. tensorpack allows and encourages complex data processing.
2. Data. tensorpack allows and encourages complex data processing.
+ All data producers have a unified `DataFlow` interface, allowing them to be composed to perform complex preprocessing.
+ Use Python to easily handle your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
......@@ -30,7 +30,7 @@ Focus on modularity. You just have to define the following three components to s
+ Run inference on a test dataset
With the above components defined, tensorpack trainer will run the training iterations for you.
Multi-GPU training is ready to use by simply changing the trainer.
Multi-GPU training is ready to use by simply switching the trainer.
## Dependencies:
......
......@@ -18,8 +18,10 @@ from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.RL import *
import common
from common import play_model, Evaluator, eval_model_multithread
from atari import AtariPlayer
BATCH_SIZE = 64
IMAGE_SIZE = (84, 84)
......@@ -54,7 +56,7 @@ def get_player(viz=False, train=False):
if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 20000)
pl = LimitLengthPlayer(pl, 30000)
return pl
common.get_player = get_player # so that eval functions in common can use the player
......
......@@ -10,15 +10,12 @@ from collections import deque
import threading
import six
from six.moves import range
from ..utils import get_rng, logger, memoized, get_dataset_path
from ..utils.stat import StatCounter
from tensorpack.utils import get_rng, logger, memoized, get_dataset_path
from tensorpack.utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
try:
from ale_python_interface import ALEInterface
except ImportError:
logger.warn("Cannot import ale_python_interface, Atari won't be available.")
from ale_python_interface import ALEInterface
__all__ = ['AtariPlayer']
......
......@@ -42,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode.
Does auto-reset, but doesn't auto-restart the underlying player.
Will auto restart the underlying player on timeout
"""
def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player)
......@@ -55,11 +55,12 @@ class LimitLengthPlayer(ProxyPlayer):
if self.cnt >= self.limit:
isOver = True
if isOver:
self.cnt = 0
self.finish_episode()
self.restart_episode()
return (r, isOver)
def restart_episode(self):
super(LimitLengthPlayer, self).restart_episode()
self.player.restart_episode()
self.cnt = 0
class AutoRestartPlayer(ProxyPlayer):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: gymenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
try:
import gym
except ImportError:
logger.warn("Cannot import gym. GymEnv won't be available.")
import time
from ..utils import logger
from ..utils.fs import *
from ..utils.stat import *
from .envbase import RLEnvironment, DiscreteActionSpace
class GymEnv(RLEnvironment):
    """
    A thin wrapper around an OpenAI gym environment.

    Episodes are restarted automatically: whenever ``step`` reports the
    episode is over, the score is recorded and the environment is reset.
    """
    def __init__(self, name, dumpdir=None, viz=False):
        """
        :param name: a gym environment id, e.g. ``'Breakout-v0'``
        :param dumpdir: currently unused (monitor dumping is disabled below)
        :param viz: False to disable rendering; otherwise a float used as the
            per-frame sleep (in seconds) when rendering in ``current_state``
        """
        self.gymenv = gym.make(name)
        # Monitor dumping is intentionally disabled for now:
        # if dumpdir:
        #     mkdir_p(dumpdir)
        #     self.gymenv.monitor.start(dumpdir, force=True, seed=0)
        self.viz = viz
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self.restart_episode()

    def restart_episode(self):
        """Reset the reward accumulator and the underlying gym env."""
        self.rwd_counter.reset()
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        """Record the accumulated episode reward into the stats."""
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        """Return the last observation, optionally rendering first."""
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        return self._ob

    def action(self, act):
        """
        Perform one action.

        :param act: an action index for the discrete action space
        :returns: (reward, isOver); auto-restarts when the episode ends
        """
        self._ob, reward, isOver, _ = self.gymenv.step(act)
        self.rwd_counter.feed(reward)
        if isOver:
            self.finish_episode()
            self.restart_episode()
        return reward, isOver

    def get_action_space(self):
        """Return the (discrete) action space of the wrapped env."""
        space = self.gymenv.action_space
        assert isinstance(space, gym.spaces.discrete.Discrete)
        return DiscreteActionSpace(space.n)
# Manual smoke test: play random actions on Breakout and print every
# nonzero reward or episode end.
# NOTE(review): this uses a Python-2 print statement and a relative
# wildcard import, so it only works under Python 2 when run as a module
# inside the package (e.g. `python -m tensorpack.RL.gymenv`) -- running
# the file directly would fail on the relative imports.
if __name__ == '__main__':
    env = GymEnv('Breakout-v0', viz=0.1)
    num = env.get_action_space().num_actions()
    from ..utils import *
    rng = get_rng(num)
    while True:
        # Sample a uniformly random action each step.
        act = rng.choice(range(num))
        #print act
        r, o = env.action(act)
        # current_state() triggers rendering because viz is truthy.
        env.current_state()
        if r != 0 or o:
            print r, o
......@@ -6,6 +6,7 @@
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
import inspect
from ..utils import logger
from .symbolic_functions import rms
from .summary import add_moving_summary
......@@ -37,11 +38,19 @@ class MapGradient(GradientProcessor):
"""
def __init__(self, func, regex='.*'):
    """
    :param func: takes a grad or a (grad, var) pair and returns a grad.
        If it returns None, the gradient is discarded.
    :param regex: used to match variable names. Defaults to matching
        all variables.
    """
    # Accept both 1-arg (grad) and 2-arg (grad, var) callables by
    # normalizing them to a uniform 2-arg interface.
    # NOTE(review): inspect.getargspec was removed in Python 3.11;
    # use inspect.getfullargspec when porting to modern Python.
    args = inspect.getargspec(func).args
    # A bound method carries an implicit `self` that is not a real argument.
    arg_num = len(args) - inspect.ismethod(func)
    assert arg_num in [1, 2], \
        "The function must take 1 or 2 arguments!  ({})".format(args)
    if arg_num == 1:
        self.func = lambda grad, var: func(grad)
    else:
        self.func = func
    # Anchor the pattern so it must match the full variable name.
    if not regex.endswith('$'):
        regex = regex + '$'
    self.regex = regex
......
......@@ -105,8 +105,8 @@ class SaverRestore(SessionInit):
def _get_vars_to_restore_multimap(self, vars_available):
"""
Get a dict of {var_name: [var, var]} to restore
:param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
"""
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list)
......@@ -114,12 +114,11 @@ class SaverRestore(SessionInit):
for v in vars_to_restore:
name = v.op.name
if 'towerp' in name:
logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name))
logger.error("No variable should be under 'towerp' name scope".format(v.name))
# don't overwrite anything in the current prediction graph
continue
if 'tower' in name:
new_name = re.sub('tower[p0-9]+/', '', name)
name = new_name
name = re.sub('tower[p0-9]+/', '', name)
if self.prefix and name.startswith(self.prefix):
name = name[len(self.prefix)+1:]
if name in vars_available:
......@@ -127,11 +126,11 @@ class SaverRestore(SessionInit):
chkpt_vars_used.add(name)
#vars_available.remove(name)
else:
logger.warn("Variable {} not found in checkpoint!".format(v.op.name))
logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used
for name in unused:
logger.warn("Variable {} in checkpoint doesn't exist in the graph!".format(name))
logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
return var_dict
class ParamRestore(SessionInit):
......@@ -155,9 +154,9 @@ class ParamRestore(SessionInit):
logger.info("Params to restore: {}".format(
', '.join(map(str, intersect))))
for k in variable_names - param_names:
logger.warn("Variable {} in the graph not getting restored!".format(k))
logger.warn("Variable {} in the graph not found in the dict!".format(k))
for k in param_names - variable_names:
logger.warn("Variable {} in the dict not found in this graph!".format(k))
logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
......
......@@ -6,6 +6,8 @@
import os, sys
from six.moves import urllib
__all__ = ['mkdir_p', 'download']
def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists"""
assert dirname is not None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment