Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
52aae61a
Commit
52aae61a
authored
Aug 04, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
gymenv. fix gradproc. auto-restart limitlength
parent
ee227da4
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
107 additions
and
25 deletions
+107
-25
README.md
README.md
+4
-4
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+3
-1
examples/Atari2600/atari.py
examples/Atari2600/atari.py
+4
-7
tensorpack/RL/common.py
tensorpack/RL/common.py
+4
-3
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+72
-0
tensorpack/tfutils/gradproc.py
tensorpack/tfutils/gradproc.py
+11
-2
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+7
-8
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+2
-0
No files found.
README.md
View file @
52aae61a
...
...
@@ -13,12 +13,12 @@ See some interesting [examples](examples) to learn about the framework:
## Features:
Focus on modularity. You just have to define the following three components to start a training
:
You need to abstract your training task into three components
:
1.
The model, or the
graph.
`models/`
has some scoped abstraction of common models.
1.
Model, or
graph.
`models/`
has some scoped abstraction of common models.
`LinearWrap`
and
`argscope`
makes large models look simpler.
2.
The d
ata. tensorpack allows and encourages complex data processing.
2.
D
ata. tensorpack allows and encourages complex data processing.
+ All data producer has an unified `DataFlow` interface, allowing them to be composed to perform complex preprocessing.
+ Use Python to easily handle your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.
...
...
@@ -30,7 +30,7 @@ Focus on modularity. You just have to define the following three components to s
+
Run inference on a test dataset
With the above components defined, tensorpack trainer will run the training iterations for you.
Multi-GPU training is ready to use by simply
chang
ing the trainer.
Multi-GPU training is ready to use by simply
switch
ing the trainer.
## Dependencies:
...
...
examples/Atari2600/DQN.py
View file @
52aae61a
...
...
@@ -18,8 +18,10 @@ from tensorpack.utils.concurrency import *
from
tensorpack.tfutils
import
symbolic_functions
as
symbf
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.RL
import
*
import
common
from
common
import
play_model
,
Evaluator
,
eval_model_multithread
from
atari
import
AtariPlayer
BATCH_SIZE
=
64
IMAGE_SIZE
=
(
84
,
84
)
...
...
@@ -54,7 +56,7 @@ def get_player(viz=False, train=False):
if
not
train
:
pl
=
HistoryFramePlayer
(
pl
,
FRAME_HISTORY
)
pl
=
PreventStuckPlayer
(
pl
,
30
,
1
)
pl
=
LimitLengthPlayer
(
pl
,
2
0000
)
pl
=
LimitLengthPlayer
(
pl
,
3
0000
)
return
pl
common
.
get_player
=
get_player
# so that eval functions in common can use the player
...
...
tensorpack/RL
/atari.py
→
examples/Atari2600
/atari.py
View file @
52aae61a
...
...
@@ -10,15 +10,12 @@ from collections import deque
import
threading
import
six
from
six.moves
import
range
from
.
.utils
import
get_rng
,
logger
,
memoized
,
get_dataset_path
from
.
.utils.stat
import
StatCounter
from
tensorpack
.utils
import
get_rng
,
logger
,
memoized
,
get_dataset_path
from
tensorpack
.utils.stat
import
StatCounter
from
.envbase
import
RLEnvironment
,
DiscreteActionSpace
from
tensorpack.RL
.envbase
import
RLEnvironment
,
DiscreteActionSpace
try
:
from
ale_python_interface
import
ALEInterface
except
ImportError
:
logger
.
warn
(
"Cannot import ale_python_interface, Atari won't be available."
)
from
ale_python_interface
import
ALEInterface
__all__
=
[
'AtariPlayer'
]
...
...
tensorpack/RL/common.py
View file @
52aae61a
...
...
@@ -42,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
class
LimitLengthPlayer
(
ProxyPlayer
):
""" Limit the total number of actions in an episode.
Does auto-reset, but doesn't auto-restart the underlying player.
Will auto restart the underlying player on timeout
"""
def
__init__
(
self
,
player
,
limit
):
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
...
...
@@ -55,11 +55,12 @@ class LimitLengthPlayer(ProxyPlayer):
if
self
.
cnt
>=
self
.
limit
:
isOver
=
True
if
isOver
:
self
.
cnt
=
0
self
.
finish_episode
()
self
.
restart_episode
()
return
(
r
,
isOver
)
def
restart_episode
(
self
):
s
uper
(
LimitLengthPlayer
,
self
)
.
restart_episode
()
s
elf
.
player
.
restart_episode
()
self
.
cnt
=
0
class
AutoRestartPlayer
(
ProxyPlayer
):
...
...
tensorpack/RL/gymenv.py
0 → 100644
View file @
52aae61a
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: gymenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
try
:
import
gym
except
ImportError
:
logger
.
warn
(
"Cannot import gym. GymEnv won't be available."
)
import
time
from
..utils
import
logger
from
..utils.fs
import
*
from
..utils.stat
import
*
from
.envbase
import
RLEnvironment
,
DiscreteActionSpace
class
GymEnv
(
RLEnvironment
):
"""
An OpenAI/gym wrapper. Will auto restart.
"""
def
__init__
(
self
,
name
,
dumpdir
=
None
,
viz
=
False
):
self
.
gymenv
=
gym
.
make
(
name
)
#if dumpdir:
#mkdir_p(dumpdir)
#self.gymenv.monitor.start(dumpdir, force=True, seed=0)
self
.
reset_stat
()
self
.
rwd_counter
=
StatCounter
()
self
.
restart_episode
()
self
.
viz
=
viz
def
restart_episode
(
self
):
self
.
rwd_counter
.
reset
()
self
.
_ob
=
self
.
gymenv
.
reset
()
def
finish_episode
(
self
):
self
.
stats
[
'score'
]
.
append
(
self
.
rwd_counter
.
sum
)
def
current_state
(
self
):
if
self
.
viz
:
self
.
gymenv
.
render
()
time
.
sleep
(
self
.
viz
)
return
self
.
_ob
def
action
(
self
,
act
):
self
.
_ob
,
r
,
isOver
,
info
=
self
.
gymenv
.
step
(
act
)
self
.
rwd_counter
.
feed
(
r
)
if
isOver
:
self
.
finish_episode
()
self
.
restart_episode
()
return
r
,
isOver
def
get_action_space
(
self
):
spc
=
self
.
gymenv
.
action_space
assert
isinstance
(
spc
,
gym
.
spaces
.
discrete
.
Discrete
)
return
DiscreteActionSpace
(
spc
.
n
)
if
__name__
==
'__main__'
:
env
=
GymEnv
(
'Breakout-v0'
,
viz
=
0.1
)
num
=
env
.
get_action_space
()
.
num_actions
()
from
..utils
import
*
rng
=
get_rng
(
num
)
while
True
:
act
=
rng
.
choice
(
range
(
num
))
#print act
r
,
o
=
env
.
action
(
act
)
env
.
current_state
()
if
r
!=
0
or
o
:
print
r
,
o
tensorpack/tfutils/gradproc.py
View file @
52aae61a
...
...
@@ -6,6 +6,7 @@
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
import
re
import
inspect
from
..utils
import
logger
from
.symbolic_functions
import
rms
from
.summary
import
add_moving_summary
...
...
@@ -37,11 +38,19 @@ class MapGradient(GradientProcessor):
"""
def
__init__
(
self
,
func
,
regex
=
'.*'
):
"""
:param func: takes a (grad, var) pair and returns a grad. If return None, the
:param func: takes a
grad or
(grad, var) pair and returns a grad. If return None, the
gradient is discarded.
:param regex: used to match variables. default to match all variables.
"""
self
.
func
=
func
args
=
inspect
.
getargspec
(
func
)
.
args
arg_num
=
len
(
args
)
-
inspect
.
ismethod
(
func
)
assert
arg_num
in
[
1
,
2
],
\
"The function must take 1 or 2 arguments! ({})"
.
format
(
args
)
if
arg_num
==
1
:
self
.
func
=
lambda
grad
,
var
:
func
(
grad
)
else
:
self
.
func
=
func
if
not
regex
.
endswith
(
'$'
):
regex
=
regex
+
'$'
self
.
regex
=
regex
...
...
tensorpack/tfutils/sessinit.py
View file @
52aae61a
...
...
@@ -105,8 +105,8 @@ class SaverRestore(SessionInit):
def
_get_vars_to_restore_multimap
(
self
,
vars_available
):
"""
Get a dict of {var_name: [var, var]} to restore
:param vars_available: varaible names available in the checkpoint, for existence checking
:returns: a dict of {var_name: [var, var]} to restore
"""
vars_to_restore
=
tf
.
all_variables
()
var_dict
=
defaultdict
(
list
)
...
...
@@ -114,12 +114,11 @@ class SaverRestore(SessionInit):
for
v
in
vars_to_restore
:
name
=
v
.
op
.
name
if
'towerp'
in
name
:
logger
.
warn
(
"Variable {} in prediction tower shouldn't exist.
"
.
format
(
v
.
name
))
logger
.
error
(
"No variable should be under 'towerp' name scope
"
.
format
(
v
.
name
))
# don't overwrite anything in the current prediction graph
continue
if
'tower'
in
name
:
new_name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
name
)
name
=
new_name
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
name
)
if
self
.
prefix
and
name
.
startswith
(
self
.
prefix
):
name
=
name
[
len
(
self
.
prefix
)
+
1
:]
if
name
in
vars_available
:
...
...
@@ -127,11 +126,11 @@ class SaverRestore(SessionInit):
chkpt_vars_used
.
add
(
name
)
#vars_available.remove(name)
else
:
logger
.
warn
(
"Variable {} not found in checkpoint!"
.
format
(
v
.
op
.
name
))
logger
.
warn
(
"Variable {}
in the graph
not found in checkpoint!"
.
format
(
v
.
op
.
name
))
if
len
(
chkpt_vars_used
)
<
len
(
vars_available
):
unused
=
vars_available
-
chkpt_vars_used
for
name
in
unused
:
logger
.
warn
(
"Variable {} in checkpoint
doesn't exist
in the graph!"
.
format
(
name
))
logger
.
warn
(
"Variable {} in checkpoint
not found
in the graph!"
.
format
(
name
))
return
var_dict
class
ParamRestore
(
SessionInit
):
...
...
@@ -155,9 +154,9 @@ class ParamRestore(SessionInit):
logger
.
info
(
"Params to restore: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
for
k
in
variable_names
-
param_names
:
logger
.
warn
(
"Variable {} in the graph not
getting restored
!"
.
format
(
k
))
logger
.
warn
(
"Variable {} in the graph not
found in the dict
!"
.
format
(
k
))
for
k
in
param_names
-
variable_names
:
logger
.
warn
(
"Variable {} in the dict not found in th
is
graph!"
.
format
(
k
))
logger
.
warn
(
"Variable {} in the dict not found in th
e
graph!"
.
format
(
k
))
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
...
...
tensorpack/utils/fs.py
View file @
52aae61a
...
...
@@ -6,6 +6,8 @@
import
os
,
sys
from
six.moves
import
urllib
__all__
=
[
'mkdir_p'
,
'download'
]
def
mkdir_p
(
dirname
):
""" make a dir recursively, but do nothing if the dir exists"""
assert
dirname
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment