Commit 52aae61a authored by Yuxin Wu's avatar Yuxin Wu

gymenv. fix gradproc. auto-restart limitlength

parent ee227da4
...@@ -13,12 +13,12 @@ See some interesting [examples](examples) to learn about the framework: ...@@ -13,12 +13,12 @@ See some interesting [examples](examples) to learn about the framework:
## Features: ## Features:
Focus on modularity. You just have to define the following three components to start a training: You need to abstract your training task into three components:
1. The model, or the graph. `models/` has some scoped abstraction of common models. 1. Model, or graph. `models/` has some scoped abstraction of common models.
`LinearWrap` and `argscope` makes large models look simpler. `LinearWrap` and `argscope` makes large models look simpler.
2. The data. tensorpack allows and encourages complex data processing. 2. Data. tensorpack allows and encourages complex data processing.
+ All data producers have a unified `DataFlow` interface, allowing them to be composed to perform complex preprocessing. + All data producers have a unified `DataFlow` interface, allowing them to be composed to perform complex preprocessing.
+ Use Python to easily handle your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch. + Use Python to easily handle your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
...@@ -30,7 +30,7 @@ Focus on modularity. You just have to define the following three components to s ...@@ -30,7 +30,7 @@ Focus on modularity. You just have to define the following three components to s
+ Run inference on a test dataset + Run inference on a test dataset
With the above components defined, tensorpack trainer will run the training iterations for you. With the above components defined, tensorpack trainer will run the training iterations for you.
Multi-GPU training is ready to use by simply changing the trainer. Multi-GPU training is ready to use by simply switching the trainer.
## Dependencies: ## Dependencies:
......
...@@ -18,8 +18,10 @@ from tensorpack.utils.concurrency import * ...@@ -18,8 +18,10 @@ from tensorpack.utils.concurrency import *
from tensorpack.tfutils import symbolic_functions as symbf from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.RL import * from tensorpack.RL import *
import common import common
from common import play_model, Evaluator, eval_model_multithread from common import play_model, Evaluator, eval_model_multithread
from atari import AtariPlayer
BATCH_SIZE = 64 BATCH_SIZE = 64
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
...@@ -54,7 +56,7 @@ def get_player(viz=False, train=False): ...@@ -54,7 +56,7 @@ def get_player(viz=False, train=False):
if not train: if not train:
pl = HistoryFramePlayer(pl, FRAME_HISTORY) pl = HistoryFramePlayer(pl, FRAME_HISTORY)
pl = PreventStuckPlayer(pl, 30, 1) pl = PreventStuckPlayer(pl, 30, 1)
pl = LimitLengthPlayer(pl, 20000) pl = LimitLengthPlayer(pl, 30000)
return pl return pl
common.get_player = get_player # so that eval functions in common can use the player common.get_player = get_player # so that eval functions in common can use the player
......
...@@ -10,15 +10,12 @@ from collections import deque ...@@ -10,15 +10,12 @@ from collections import deque
import threading import threading
import six import six
from six.moves import range from six.moves import range
from ..utils import get_rng, logger, memoized, get_dataset_path from tensorpack.utils import get_rng, logger, memoized, get_dataset_path
from ..utils.stat import StatCounter from tensorpack.utils.stat import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
try: from ale_python_interface import ALEInterface
from ale_python_interface import ALEInterface
except ImportError:
logger.warn("Cannot import ale_python_interface, Atari won't be available.")
__all__ = ['AtariPlayer'] __all__ = ['AtariPlayer']
......
...@@ -42,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -42,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
class LimitLengthPlayer(ProxyPlayer): class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode. """ Limit the total number of actions in an episode.
Does auto-reset, but doesn't auto-restart the underlying player. Will auto restart the underlying player on timeout
""" """
def __init__(self, player, limit): def __init__(self, player, limit):
super(LimitLengthPlayer, self).__init__(player) super(LimitLengthPlayer, self).__init__(player)
...@@ -55,11 +55,12 @@ class LimitLengthPlayer(ProxyPlayer): ...@@ -55,11 +55,12 @@ class LimitLengthPlayer(ProxyPlayer):
if self.cnt >= self.limit: if self.cnt >= self.limit:
isOver = True isOver = True
if isOver: if isOver:
self.cnt = 0 self.finish_episode()
self.restart_episode()
return (r, isOver) return (r, isOver)
def restart_episode(self): def restart_episode(self):
super(LimitLengthPlayer, self).restart_episode() self.player.restart_episode()
self.cnt = 0 self.cnt = 0
class AutoRestartPlayer(ProxyPlayer): class AutoRestartPlayer(ProxyPlayer):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: gymenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import time

# Import the project logger BEFORE the gym import: the except branch
# below uses `logger`, and with the original ordering a missing gym
# package raised NameError instead of emitting the warning.
from ..utils import logger
from ..utils.fs import *
from ..utils.stat import *
from .envbase import RLEnvironment, DiscreteActionSpace

try:
    import gym
except ImportError:
    logger.warn("Cannot import gym. GymEnv won't be available.")
class GymEnv(RLEnvironment):
    """
    A thin wrapper around an OpenAI gym environment.

    Automatically restarts the underlying environment when an episode
    finishes, and records each finished episode's total reward into
    ``self.stats['score']``.
    """
    def __init__(self, name, dumpdir=None, viz=False):
        """
        :param name: gym environment id, e.g. ``'Breakout-v0'``.
        :param dumpdir: currently unused; monitor dumping is disabled below.
        :param viz: if truthy, render every state and sleep this many
            seconds per frame in :meth:`current_state`.
        """
        self.gymenv = gym.make(name)
        # Monitor dumping kept for reference but disabled:
        #if dumpdir:
            #mkdir_p(dumpdir)
            #self.gymenv.monitor.start(dumpdir, force=True, seed=0)
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self.restart_episode()
        self.viz = viz

    def restart_episode(self):
        """Reset the reward counter and the underlying gym environment."""
        self.rwd_counter.reset()
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        """Record the accumulated reward of the episode that just ended."""
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        """Return the most recent observation, optionally rendering it first."""
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        return self._ob

    def action(self, act):
        """
        Perform one environment step.

        :param act: the action to take.
        :returns: a ``(reward, is_over)`` pair for this single step.
        """
        self._ob, reward, is_over, _info = self.gymenv.step(act)
        self.rwd_counter.feed(reward)
        if is_over:
            # episode done: log its score, then auto-restart
            self.finish_episode()
            self.restart_episode()
        return reward, is_over

    def get_action_space(self):
        """Return the action space; only discrete action spaces are supported."""
        space = self.gymenv.action_space
        assert isinstance(space, gym.spaces.discrete.Discrete)
        return DiscreteActionSpace(space.n)
# Manual smoke test: play Breakout with uniformly random actions and print
# every non-zero reward or episode end.
# NOTE(review): `print r, o` is Python 2 statement syntax, and the relative
# import `from ..utils import *` only works when this module is run with
# `python -m <package>.gymenv`, not as a plain script -- confirm intended usage.
if __name__ == '__main__':
    env = GymEnv('Breakout-v0', viz=0.1)
    num = env.get_action_space().num_actions()
    from ..utils import *
    rng = get_rng(num)
    while True:
        # sample a random action each step
        act = rng.choice(range(num))
        #print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
            print r, o
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import re import re
import inspect
from ..utils import logger from ..utils import logger
from .symbolic_functions import rms from .symbolic_functions import rms
from .summary import add_moving_summary from .summary import add_moving_summary
...@@ -37,11 +38,19 @@ class MapGradient(GradientProcessor): ...@@ -37,11 +38,19 @@ class MapGradient(GradientProcessor):
""" """
def __init__(self, func, regex='.*'): def __init__(self, func, regex='.*'):
""" """
:param func: takes a (grad, var) pair and returns a grad. If return None, the :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
gradient is discarded. gradient is discarded.
:param regex: used to match variables. default to match all variables. :param regex: used to match variables. default to match all variables.
""" """
self.func = func args = inspect.getargspec(func).args
arg_num = len(args) - inspect.ismethod(func)
assert arg_num in [1, 2], \
"The function must take 1 or 2 arguments! ({})".format(args)
if arg_num == 1:
self.func = lambda grad, var: func(grad)
else:
self.func = func
if not regex.endswith('$'): if not regex.endswith('$'):
regex = regex + '$' regex = regex + '$'
self.regex = regex self.regex = regex
......
...@@ -105,8 +105,8 @@ class SaverRestore(SessionInit): ...@@ -105,8 +105,8 @@ class SaverRestore(SessionInit):
def _get_vars_to_restore_multimap(self, vars_available): def _get_vars_to_restore_multimap(self, vars_available):
""" """
Get a dict of {var_name: [var, var]} to restore
:param vars_available: variable names available in the checkpoint, for existence checking :param vars_available: variable names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
""" """
vars_to_restore = tf.all_variables() vars_to_restore = tf.all_variables()
var_dict = defaultdict(list) var_dict = defaultdict(list)
...@@ -114,12 +114,11 @@ class SaverRestore(SessionInit): ...@@ -114,12 +114,11 @@ class SaverRestore(SessionInit):
for v in vars_to_restore: for v in vars_to_restore:
name = v.op.name name = v.op.name
if 'towerp' in name: if 'towerp' in name:
logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name)) logger.error("No variable should be under 'towerp' name scope".format(v.name))
# don't overwrite anything in the current prediction graph # don't overwrite anything in the current prediction graph
continue continue
if 'tower' in name: if 'tower' in name:
new_name = re.sub('tower[p0-9]+/', '', name) name = re.sub('tower[p0-9]+/', '', name)
name = new_name
if self.prefix and name.startswith(self.prefix): if self.prefix and name.startswith(self.prefix):
name = name[len(self.prefix)+1:] name = name[len(self.prefix)+1:]
if name in vars_available: if name in vars_available:
...@@ -127,11 +126,11 @@ class SaverRestore(SessionInit): ...@@ -127,11 +126,11 @@ class SaverRestore(SessionInit):
chkpt_vars_used.add(name) chkpt_vars_used.add(name)
#vars_available.remove(name) #vars_available.remove(name)
else: else:
logger.warn("Variable {} not found in checkpoint!".format(v.op.name)) logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
if len(chkpt_vars_used) < len(vars_available): if len(chkpt_vars_used) < len(vars_available):
unused = vars_available - chkpt_vars_used unused = vars_available - chkpt_vars_used
for name in unused: for name in unused:
logger.warn("Variable {} in checkpoint doesn't exist in the graph!".format(name)) logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
return var_dict return var_dict
class ParamRestore(SessionInit): class ParamRestore(SessionInit):
...@@ -155,9 +154,9 @@ class ParamRestore(SessionInit): ...@@ -155,9 +154,9 @@ class ParamRestore(SessionInit):
logger.info("Params to restore: {}".format( logger.info("Params to restore: {}".format(
', '.join(map(str, intersect)))) ', '.join(map(str, intersect))))
for k in variable_names - param_names: for k in variable_names - param_names:
logger.warn("Variable {} in the graph not getting restored!".format(k)) logger.warn("Variable {} in the graph not found in the dict!".format(k))
for k in param_names - variable_names: for k in param_names - variable_names:
logger.warn("Variable {} in the dict not found in this graph!".format(k)) logger.warn("Variable {} in the dict not found in the graph!".format(k))
upd = SessionUpdate(sess, [v for v in variables if v.name in intersect]) upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
......
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
import os, sys import os, sys
from six.moves import urllib from six.moves import urllib
__all__ = ['mkdir_p', 'download']
def mkdir_p(dirname): def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists""" """ make a dir recursively, but do nothing if the dir exists"""
assert dirname is not None assert dirname is not None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment