Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
030a1d31
Commit
030a1d31
authored
Aug 07, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
param dumper
parent
6a2425d0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
46 additions
and
15 deletions
+46
-15
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-3
examples/Atari2600/common.py
examples/Atari2600/common.py
+11
-2
scripts/dump-model-params.py
scripts/dump-model-params.py
+23
-7
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+3
-3
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+7
-0
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+1
-0
No files found.
examples/Atari2600/DQN.py
View file @
030a1d31
...
...
@@ -115,9 +115,7 @@ class Model(ModelDesc):
target
=
reward
+
(
1.0
-
tf
.
cast
(
isOver
,
tf
.
float32
))
*
GAMMA
*
tf
.
stop_gradient
(
best_v
)
sqrcost
=
tf
.
square
(
target
-
pred_action_value
)
abscost
=
tf
.
abs
(
target
-
pred_action_value
)
# robust error func
cost
=
tf
.
select
(
abscost
<
1
,
sqrcost
,
abscost
)
cost
=
symbf
.
clipped_l2_loss
(
target
-
pred_action_value
)
summary
.
add_param_summary
([(
'conv.*/W'
,
[
'histogram'
,
'rms'
]),
(
'fc.*/W'
,
[
'histogram'
,
'rms'
])
])
# monitor all W
self
.
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cost'
)
...
...
examples/Atari2600/common.py
View file @
030a1d31
...
...
@@ -39,12 +39,21 @@ def eval_with_funcs(predict_funcs, nr_eval):
class
Worker
(
StoppableThread
):
def
__init__
(
self
,
func
,
queue
):
super
(
Worker
,
self
)
.
__init__
()
self
.
func
=
func
self
.
_
func
=
func
self
.
q
=
queue
def
func
(
self
,
*
args
,
**
kwargs
):
if
self
.
stopped
():
raise
RuntimeError
(
"stopped!"
)
return
self
.
_func
(
*
args
,
**
kwargs
)
def
run
(
self
):
player
=
get_player
()
while
not
self
.
stopped
():
score
=
play_one_episode
(
player
,
self
.
func
)
try
:
score
=
play_one_episode
(
player
,
self
.
func
)
except
RuntimeError
:
return
self
.
queue_put_stoppable
(
self
.
q
,
score
)
q
=
queue
.
Queue
(
maxsize
=
2
)
...
...
scripts/dump-model-params.py
View file @
030a1d31
...
...
@@ -7,23 +7,39 @@ import argparse
import
tensorflow
as
tf
import
imp
from
tensorpack
.utils
import
*
from
tensorpack
import
*
from
tensorpack.tfutils
import
sessinit
,
varmanip
from
tensorpack.dataflow
import
*
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
dest
=
'config'
)
parser
.
add_argument
(
'--config'
,
help
=
'config file'
)
parser
.
add_argument
(
'--meta'
,
help
=
'metagraph file'
)
parser
.
add_argument
(
dest
=
'model'
)
parser
.
add_argument
(
dest
=
'output'
)
args
=
parser
.
parse_args
()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
assert
args
.
config
or
args
.
meta
,
"Either config or metagraph must be present!"
with
tf
.
Graph
()
.
as_default
()
as
G
:
config
=
get_config_func
()
config
.
model
.
build_graph
(
config
.
model
.
get_input_vars
(),
is_training
=
False
)
if
args
.
config
:
MODEL
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
Model
M
=
MODEL
()
M
.
build_graph
(
M
.
get_input_vars
(),
is_training
=
False
)
else
:
M
=
ModelFromMetaGraph
(
args
.
meta
)
# loading...
init
=
sessinit
.
SaverRestore
(
args
.
model
)
sess
=
tf
.
Session
()
init
.
init
(
sess
)
# dump ...
with
sess
.
as_default
():
varmanip
.
dump_session_params
(
args
.
output
)
if
args
.
output
.
endswith
(
'npy'
):
varmanip
.
dump_session_params
(
args
.
output
)
else
:
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
var
.
extend
(
tf
.
get_collection
(
EXTRA_SAVE_VARS_KEY
))
logger
.
info
(
"Variables to dump:"
)
logger
.
info
(
", "
.
join
([
v
.
name
for
v
in
var
]))
saver
=
tf
.
train
.
Saver
(
var_list
=
var
)
saver
.
save
(
sess
,
args
.
output
,
write_meta_graph
=
False
)
tensorpack/RL/gymenv.py
View file @
030a1d31
...
...
@@ -24,9 +24,9 @@ class GymEnv(RLEnvironment):
"""
def
__init__
(
self
,
name
,
dumpdir
=
None
,
viz
=
False
):
self
.
gymenv
=
gym
.
make
(
name
)
#
if dumpdir:
#
mkdir_p(dumpdir)
#self.gymenv.monitor.start(dumpdir, force=True, seed=0
)
if
dumpdir
:
mkdir_p
(
dumpdir
)
self
.
gymenv
.
monitor
.
start
(
dumpdir
)
self
.
reset_stat
()
self
.
rwd_counter
=
StatCounter
()
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
030a1d31
...
...
@@ -78,6 +78,13 @@ def rms(x, name=None):
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
return
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
x
)),
name
=
name
)
def
clipped_l2_loss
(
x
,
name
=
None
):
if
name
is
None
:
name
=
'clipped_l2_loss'
sqrcost
=
tf
.
square
(
x
)
abscost
=
tf
.
abs
(
x
)
return
tf
.
select
(
abscost
<
1
,
sqrcost
,
abscost
,
name
=
name
)
def
get_scalar_var
(
name
,
init_value
):
return
tf
.
get_variable
(
name
,
shape
=
[],
initializer
=
tf
.
constant_initializer
(
init_value
),
...
...
tensorpack/tfutils/varmanip.py
View file @
030a1d31
...
...
@@ -9,6 +9,7 @@ from collections import defaultdict
import
re
import
numpy
as
np
from
..utils
import
logger
from
..utils.naming
import
*
__all__
=
[
'SessionUpdate'
,
'dump_session_params'
,
'dump_chkpt_vars'
,
'get_savename_from_varname'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment