Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6a2425d0
Commit
6a2425d0
authored
Aug 07, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
better variable name management
parent
4ee67733
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
23 deletions
+39
-23
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+20
-14
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+14
-6
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+4
-2
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+1
-1
No files found.
tensorpack/callbacks/common.py
View file @
6a2425d0
...
...
@@ -8,6 +8,7 @@ import re
from
.base
import
Callback
from
..utils
import
*
from
..tfutils.varmanip
import
get_savename_from_varname
__all__
=
[
'ModelSaver'
]
...
...
@@ -15,37 +16,42 @@ class ModelSaver(Callback):
"""
Save the model to logger directory.
"""
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
):
def
__init__
(
self
,
keep_recent
=
10
,
keep_freq
=
0.5
,
var_collections
=
tf
.
GraphKeys
.
VARIABLES
):
"""
:param keep_recent: see `tf.train.Saver` documentation.
:param keep_freq: see `tf.train.Saver` documentation.
"""
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
if
not
isinstance
(
var_collections
,
list
):
var_collections
=
[
var_collections
]
self
.
var_collections
=
var_collections
def
_setup_graph
(
self
):
vars
=
[]
for
key
in
self
.
var_collections
:
vars
.
extend
(
tf
.
get_collection
(
key
))
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
saver
=
tf
.
train
.
Saver
(
var_list
=
ModelSaver
.
_get_var
s
(
),
var_list
=
ModelSaver
.
_get_var
_dict
(
vars
),
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
self
.
meta_graph_written
=
False
@
staticmethod
def
_get_vars
():
vars
=
tf
.
all_variables
()
def
_get_var_dict
(
vars
):
var_dict
=
{}
for
v
in
vars
:
name
=
v
.
name
if
re
.
match
(
'tower[p1-9]'
,
name
):
#logger.info("Skip {} when saving model.".format(name))
continue
if
'tower0/'
in
name
:
new_name
=
name
.
replace
(
'tower0/'
,
''
)
logger
.
info
(
"{} renamed to {} when saving model."
.
format
(
name
,
new_name
))
name
=
new_name
var_dict
[
name
]
=
v
name
=
get_savename_from_varname
(
v
.
name
)
if
name
not
in
var_dict
:
if
name
!=
v
.
name
:
logger
.
info
(
"{} renamed to {} when saving model."
.
format
(
v
.
name
,
name
))
var_dict
[
name
]
=
v
else
:
logger
.
warn
(
"Variable {} won't be saved
\
because {} will be saved"
.
format
(
v
.
name
,
var_dict
[
name
]
.
name
))
return
var_dict
def
_trigger_epoch
(
self
):
...
...
tensorpack/callbacks/param.py
View file @
6a2425d0
...
...
@@ -15,7 +15,7 @@ from ..tfutils import get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'
NonDecreasing
StatMonitorParamSetter'
,
'StatMonitorParamSetter'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
class
HyperParam
(
object
):
...
...
@@ -176,14 +176,15 @@ class ScheduledHyperParamSetter(HyperParamSetter):
return
v
return
None
class
NonDecreasing
StatMonitorParamSetter
(
HyperParamSetter
):
class
StatMonitorParamSetter
(
HyperParamSetter
):
"""
Set hyperparameter by a func, if a specific stat wasn't
monotonically decreasing $a$ times out of the last $b$ epochs
monotonically decreasing
/increasing
$a$ times out of the last $b$ epochs
"""
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
last_k
=
5
,
min_non_decreasing
=
2
min_non_decreasing
=
2
,
reverse
=
False
):
"""
Change param by `new_value = value_func(old_value)`,
...
...
@@ -192,6 +193,8 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
For example, if error wasn't decreasing, anneal the learning rate:
NonDecreasingStatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
If reverse==True, use 'increasing' instead of decreasing
"""
super
(
NonDecreasingStatMonitorParamSetter
,
self
)
.
__init__
(
param
)
self
.
stat_name
=
stat_name
...
...
@@ -200,6 +203,11 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
self
.
min_non_decreasing
=
min_non_decreasing
self
.
last_changed_epoch
=
0
if
not
reverse
:
self
.
less_than
=
lambda
x
,
y
:
x
<=
y
else
:
self
.
less_than
=
lambda
x
,
y
:
x
>=
y
def
_get_value_to_set
(
self
):
holder
=
self
.
trainer
.
stat_holder
hist
=
holder
.
get_stat_history
(
self
.
stat_name
)
...
...
@@ -209,10 +217,10 @@ class NonDecreasingStatMonitorParamSetter(HyperParamSetter):
hist
=
hist
[
-
self
.
last_k
-
1
:]
# len==last_k+1
cnt
=
0
for
k
in
range
(
self
.
last_k
):
if
hist
[
k
]
<=
hist
[
k
+
1
]
:
if
self
.
less_than
(
hist
[
k
],
hist
[
k
+
1
])
:
cnt
+=
1
if
cnt
>=
self
.
min_non_decreasing
\
and
hist
[
-
1
]
>=
hist
[
0
]
:
and
self
.
less_than
(
hist
[
0
],
hist
[
-
1
])
:
return
self
.
value_func
(
self
.
get_current_value
())
return
None
tensorpack/tfutils/varmanip.py
View file @
6a2425d0
...
...
@@ -36,7 +36,6 @@ def get_savename_from_varname(
name
=
savename_prefix
+
'/'
+
name
return
name
class
SessionUpdate
(
object
):
""" Update the variables in a session """
def
__init__
(
self
,
sess
,
vars_to_update
):
...
...
@@ -87,7 +86,10 @@ def dump_session_params(path):
var
.
extend
(
tf
.
get_collection
(
EXTRA_SAVE_VARS_KEY
))
result
=
{}
for
v
in
var
:
name
=
v
.
name
.
replace
(
":0"
,
""
)
name
=
get_savename_from_varname
(
v
.
name
)
if
name
in
result
:
logger
.
info
(
"Variable {} would be stored instead of another with
\
the same name"
.
format
(
v
.
name
))
result
[
name
]
=
v
.
eval
()
logger
.
info
(
"Variables to save to {}:"
.
format
(
path
))
logger
.
info
(
str
(
result
.
keys
()))
...
...
tensorpack/utils/naming.py
View file @
6a2425d0
...
...
@@ -11,7 +11,7 @@ MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
# placeholders for input variables
INPUT_VARS_KEY
=
'INPUT_VARIABLES'
# variables that need to be saved, apart from trainable variables
# variables that need to be saved
for inference
, apart from trainable variables
EXTRA_SAVE_VARS_KEY
=
'EXTRA_SAVE_VARIABLES'
import
tensorflow
as
tf
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment