Shashank Suhas / seminar-breakout / Commits

Commit 52aae61a
Authored Aug 04, 2016 by Yuxin Wu
gymenv. fix gradproc. auto-restart limitlength
Parent: ee227da4
Showing 8 changed files with 107 additions and 25 deletions.
Changed files:
  README.md                        +4 -4
  examples/Atari2600/DQN.py        +3 -1
  examples/Atari2600/atari.py      +4 -7
  tensorpack/RL/common.py          +4 -3
  tensorpack/RL/gymenv.py          +72 -0
  tensorpack/tfutils/gradproc.py   +11 -2
  tensorpack/tfutils/sessinit.py   +7 -8
  tensorpack/utils/fs.py           +2 -0
README.md

@@ -13,12 +13,12 @@ See some interesting [examples](examples) to learn about the framework:
 ## Features:

-Focus on modularity. You just have to define the following three components to start a training:
+You need to abstract your training task into three components:

-1. The model, or the graph. `models/` has some scoped abstraction of common models.
+1. Model, or graph. `models/` has some scoped abstraction of common models.
 `LinearWrap` and `argscope` makes large models look simpler.

-2. The data. tensorpack allows and encourages complex data processing.
+2. Data. tensorpack allows and encourages complex data processing.
   + All data producer has an unified `DataFlow` interface, allowing them to be composed to perform complex preprocessing.
   + Use Python to easily handle your own data format, yet still keep a good training speed thanks to multiprocess prefetch & TF Queue prefetch.

@@ -30,7 +30,7 @@ Focus on modularity. You just have to define the following three components to s
 + Run inference on a test dataset

 With the above components defined, tensorpack trainer will run the training iterations for you.
-Multi-GPU training is ready to use by simply changing the trainer.
+Multi-GPU training is ready to use by simply switching the trainer.

 ## Dependencies:
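As a side note on the `DataFlow` interface mentioned in the README text above, here is a minimal sketch of how a custom producer composes with the built-in ones. It assumes the 2016-era interface (a `get_data()` generator plus composition classes such as `BatchData` and `PrefetchData`); `load_my_format` and `my_files` are hypothetical placeholders.

from tensorpack.dataflow import DataFlow, BatchData, PrefetchData

class MyDataFlow(DataFlow):
    """Yield one datapoint (a list of components) at a time from a custom format."""
    def __init__(self, files):
        self.files = files

    def get_data(self):
        for f in self.files:
            img, label = load_my_format(f)   # hypothetical loader for your own data format
            yield [img, label]

df = MyDataFlow(my_files)        # custom producer
df = BatchData(df, 64)           # compose: group datapoints into batches of 64
df = PrefetchData(df, 256, 4)    # compose: prefetch with 4 worker processes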
examples/Atari2600/DQN.py

@@ -18,8 +18,10 @@ from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.RL import *

 import common
 from common import play_model, Evaluator, eval_model_multithread
+from atari import AtariPlayer

 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)

@@ -54,7 +56,7 @@ def get_player(viz=False, train=False):
     if not train:
         pl = HistoryFramePlayer(pl, FRAME_HISTORY)
         pl = PreventStuckPlayer(pl, 30, 1)
-        pl = LimitLengthPlayer(pl, 20000)
+        pl = LimitLengthPlayer(pl, 30000)
     return pl
 common.get_player = get_player  # so that eval functions in common can use the player
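For orientation, the player returned by get_player above is the outermost wrapper, so the 30000-action cap applies to the whole evaluation episode. A rough, hypothetical sketch of how such a player is driven (policy is a stand-in for the trained model; this is not the repo's evaluation code):

player = get_player(train=False)
total_reward = 0
while True:
    act = policy(player.current_state())   # hypothetical: pick an action from the current state
    r, isOver = player.action(act)
    total_reward += r
    if isOver:
        # either the game ended or LimitLengthPlayer hit its 30000-action cap;
        # after this commit the wrapper has already restarted the underlying player
        break
print(total_reward)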
tensorpack/RL/atari.py → examples/Atari2600/atari.py

@@ -10,15 +10,12 @@ from collections import deque
 import threading
 import six
 from six.moves import range

-from ..utils import get_rng, logger, memoized, get_dataset_path
-from ..utils.stat import StatCounter
-from .envbase import RLEnvironment, DiscreteActionSpace
+from tensorpack.utils import get_rng, logger, memoized, get_dataset_path
+from tensorpack.utils.stat import StatCounter
+from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace

 try:
     from ale_python_interface import ALEInterface
 except ImportError:
     logger.warn("Cannot import ale_python_interface, Atari won't be available.")

 __all__ = ['AtariPlayer']
tensorpack/RL/common.py

@@ -42,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
 class LimitLengthPlayer(ProxyPlayer):
     """ Limit the total number of actions in an episode.
-        Does auto-reset, but doesn't auto-restart the underlying player.
+        Will auto restart the underlying player on timeout
     """
     def __init__(self, player, limit):
         super(LimitLengthPlayer, self).__init__(player)

@@ -55,11 +55,12 @@ class LimitLengthPlayer(ProxyPlayer):
         if self.cnt >= self.limit:
             isOver = True
         if isOver:
-            self.cnt = 0
+            self.finish_episode()
+            self.restart_episode()
         return (r, isOver)

     def restart_episode(self):
-        super(LimitLengthPlayer, self).restart_episode()
+        self.player.restart_episode()
         self.cnt = 0

 class AutoRestartPlayer(ProxyPlayer):
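To make the behavioral change concrete, here is a small self-contained sketch that distills the old and new timeout handling. It re-implements the wrapper logic in miniature rather than importing the repo's class, so everything below is illustrative only.

class DummyPlayer(object):
    """A fake underlying player that never ends an episode on its own."""
    def __init__(self):
        self.restarts = 0
    def action(self, act):
        return (0, False)
    def finish_episode(self):
        pass
    def restart_episode(self):
        self.restarts += 1

class MiniLimitLength(object):
    """Distilled version of LimitLengthPlayer.action() after this commit."""
    def __init__(self, player, limit):
        self.player, self.limit, self.cnt = player, limit, 0
    def action(self, act):
        r, isOver = self.player.action(act)
        self.cnt += 1
        if self.cnt >= self.limit:
            isOver = True
        if isOver:
            # before this commit: only `self.cnt = 0` happened here
            self.player.finish_episode()
            self.player.restart_episode()
            self.cnt = 0
        return (r, isOver)

p = MiniLimitLength(DummyPlayer(), limit=3)
for _ in range(6):
    p.action(0)
print(p.player.restarts)   # 2: the inner player was restarted after actions 3 and 6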
tensorpack/RL/gymenv.py  0 → 100644 (new file)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: gymenv.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

try:
    import gym
except ImportError:
    logger.warn("Cannot import gym. GymEnv won't be available.")
import time
from ..utils import logger
from ..utils.fs import *
from ..utils.stat import *
from .envbase import RLEnvironment, DiscreteActionSpace

class GymEnv(RLEnvironment):
    """
    An OpenAI/gym wrapper. Will auto restart.
    """
    def __init__(self, name, dumpdir=None, viz=False):
        self.gymenv = gym.make(name)
        #if dumpdir:
            #mkdir_p(dumpdir)
            #self.gymenv.monitor.start(dumpdir, force=True, seed=0)
        self.reset_stat()
        self.rwd_counter = StatCounter()
        self.restart_episode()
        self.viz = viz

    def restart_episode(self):
        self.rwd_counter.reset()
        self._ob = self.gymenv.reset()

    def finish_episode(self):
        self.stats['score'].append(self.rwd_counter.sum)

    def current_state(self):
        if self.viz:
            self.gymenv.render()
            time.sleep(self.viz)
        return self._ob

    def action(self, act):
        self._ob, r, isOver, info = self.gymenv.step(act)
        self.rwd_counter.feed(r)
        if isOver:
            self.finish_episode()
            self.restart_episode()
        return r, isOver

    def get_action_space(self):
        spc = self.gymenv.action_space
        assert isinstance(spc, gym.spaces.discrete.Discrete)
        return DiscreteActionSpace(spc.n)

if __name__ == '__main__':
    env = GymEnv('Breakout-v0', viz=0.1)
    num = env.get_action_space().num_actions()
    from ..utils import *
    rng = get_rng(num)
    while True:
        act = rng.choice(range(num))
        #print act
        r, o = env.action(act)
        env.current_state()
        if r != 0 or o:
            print r, o
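A hedged usage sketch of the new wrapper, essentially what the `__main__` block above does when written as client code. The import path assumes the module location in this commit, the step count and use of numpy for randomness are arbitrary choices, and the wrapper targets the 2016-era gym API.

import numpy as np
from tensorpack.RL.gymenv import GymEnv

env = GymEnv('Breakout-v0', viz=0.1)        # viz > 0 renders and sleeps that many seconds per frame
num = env.get_action_space().num_actions()
rng = np.random.RandomState()
for _ in range(1000):
    act = rng.randint(num)                  # act randomly
    r, isOver = env.action(act)
    env.current_state()                     # triggers rendering when viz is set
    # no manual reset needed: GymEnv calls finish_episode()/restart_episode() itself when isOver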
tensorpack/tfutils/gradproc.py

@@ -6,6 +6,7 @@
 import tensorflow as tf
 from abc import ABCMeta, abstractmethod
 import re
+import inspect
 from ..utils import logger
 from .symbolic_functions import rms
 from .summary import add_moving_summary

@@ -37,11 +38,19 @@ class MapGradient(GradientProcessor):
     """
     def __init__(self, func, regex='.*'):
         """
-        :param func: takes a (grad, var) pair and returns a grad. If return None, the
+        :param func: takes a grad or (grad, var) pair and returns a grad. If return None, the
             gradient is discarded.
         :param regex: used to match variables. default to match all variables.
         """
-        self.func = func
+        args = inspect.getargspec(func).args
+        arg_num = len(args) - inspect.ismethod(func)
+        assert arg_num in [1, 2], \
+            "The function must take 1 or 2 arguments! ({})".format(args)
+        if arg_num == 1:
+            self.func = lambda grad, var: func(grad)
+        else:
+            self.func = func
+
         if not regex.endswith('$'):
             regex = regex + '$'
         self.regex = regex
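With the change above, MapGradient accepts a mapping function of either arity; for example (the clipping and scaling values are only illustrative):

import tensorflow as tf
from tensorpack.tfutils.gradproc import MapGradient

# one-argument form (new in this commit): the function sees only the gradient
clip_all = MapGradient(lambda grad: tf.clip_by_value(grad, -0.1, 0.1))

# two-argument form (as before): the function sees (grad, var), optionally filtered by regex
scale_fc = MapGradient(lambda grad, var: grad * 0.1, regex='fc.*/W')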
tensorpack/tfutils/sessinit.py

@@ -105,8 +105,8 @@ class SaverRestore(SessionInit):
     def _get_vars_to_restore_multimap(self, vars_available):
         """
-        Get a dict of {var_name: [var, var]} to restore
         :param vars_available: varaible names available in the checkpoint, for existence checking
+        :returns: a dict of {var_name: [var, var]} to restore
         """
         vars_to_restore = tf.all_variables()
         var_dict = defaultdict(list)

@@ -114,12 +114,11 @@ class SaverRestore(SessionInit):
         for v in vars_to_restore:
             name = v.op.name
             if 'towerp' in name:
-                logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name))
+                logger.error("No variable should be under 'towerp' name scope".format(v.name))
                 # don't overwrite anything in the current prediction graph
                 continue
             if 'tower' in name:
-                new_name = re.sub('tower[p0-9]+/', '', name)
-                name = new_name
+                name = re.sub('tower[p0-9]+/', '', name)
             if self.prefix and name.startswith(self.prefix):
                 name = name[len(self.prefix)+1:]
             if name in vars_available:

@@ -127,11 +126,11 @@ class SaverRestore(SessionInit):
                 chkpt_vars_used.add(name)
                 #vars_available.remove(name)
             else:
-                logger.warn("Variable {} not found in checkpoint!".format(v.op.name))
+                logger.warn("Variable {} in the graph not found in checkpoint!".format(v.op.name))
         if len(chkpt_vars_used) < len(vars_available):
             unused = vars_available - chkpt_vars_used
             for name in unused:
-                logger.warn("Variable {} in checkpoint doesn't exist in the graph!".format(name))
+                logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
         return var_dict

 class ParamRestore(SessionInit):

@@ -155,9 +154,9 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(
             ', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            logger.warn("Variable {} in the graph not getting restored!".format(k))
+            logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
-            logger.warn("Variable {} in the dict not found in this graph!".format(k))
+            logger.warn("Variable {} in the dict not found in the graph!".format(k))

         upd = SessionUpdate(sess, [v for v in variables if v.name in intersect])
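For reference, the tower-prefix handling above reduces to a single substitution; with the pattern used in the code it behaves like this:

import re
re.sub('tower[p0-9]+/', '', 'tower0/conv0/W')   # -> 'conv0/W'
# 'towerp...' variables would also match, but the branch above now rejects them
# with logger.error before this substitution is reached.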
tensorpack/utils/fs.py

@@ -6,6 +6,8 @@
 import os, sys
 from six.moves import urllib

+__all__ = ['mkdir_p', 'download']
+
 def mkdir_p(dirname):
     """ make a dir recursively, but do nothing if the dir exists"""
     assert dirname is not None
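Adding `__all__` pins down what a wildcard import exposes, which matters here because the new gymenv.py does `from ..utils.fs import *`. A small illustration (the directory path is arbitrary):

# with __all__ = ['mkdir_p', 'download'], a star-import exposes exactly those names
from tensorpack.utils.fs import *

mkdir_p('/tmp/tensorpack-demo')   # creates the directory, no error if it already exists
# modules that fs.py itself imports (os, sys, urllib) are no longer leaked by the star-import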