Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
509c2c90
Commit
509c2c90
authored
Aug 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
non-decr stat monitor param setter
parent
838a4ba3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
63 additions
and
4 deletions
+63
-4
examples/mnist-convnet.py
examples/mnist-convnet.py
+1
-1
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+53
-3
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+9
-0
No files found.
examples/mnist-convnet.py
View file @
509c2c90
...
...
@@ -87,7 +87,7 @@ def get_config():
StatPrinter
(),
ModelSaver
(),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cost'
),
ClassificationError
()
])
[
ScalarStats
(
'cost'
),
ClassificationError
()
])
,
]),
session_config
=
get_default_sess_config
(
0.5
),
model
=
Model
(),
...
...
tensorpack/callbacks/param.py
View file @
509c2c90
...
...
@@ -15,6 +15,7 @@ from ..tfutils import get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
,
'NonDecreasingStatMonitorParamSetter'
,
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
]
class
HyperParam
(
object
):
...
...
@@ -36,7 +37,7 @@ class HyperParam(object):
return
self
.
_readable_name
class
GraphVarParam
(
HyperParam
):
""" a variable in the graph"""
""" a variable in the graph
can be a hyperparam
"""
def
__init__
(
self
,
name
,
shape
=
[]):
self
.
name
=
name
self
.
shape
=
shape
...
...
@@ -58,8 +59,11 @@ class GraphVarParam(HyperParam):
def
set_value
(
self
,
v
):
self
.
assign_op
.
eval
(
feed_dict
=
{
self
.
val_holder
:
v
})
def
get_value
(
self
):
return
self
.
var
.
eval
()
class
ObjAttrParam
(
HyperParam
):
""" an attribute of an object"""
""" an attribute of an object
can be a hyperparam
"""
def
__init__
(
self
,
obj
,
attrname
,
readable_name
=
None
):
""" :param readable_name: default to be attrname."""
self
.
obj
=
obj
...
...
@@ -72,6 +76,9 @@ class ObjAttrParam(HyperParam):
def
set_value
(
self
,
v
):
setattr
(
self
.
obj
,
self
.
attrname
,
v
)
def
get_value
(
self
,
v
):
return
getattr
(
self
.
obj
,
self
.
attrname
)
class
HyperParamSetter
(
Callback
):
"""
Base class to set hyperparameters after every epoch.
...
...
@@ -98,11 +105,14 @@ class HyperParamSetter(Callback):
"""
ret
=
self
.
_get_value_to_set
()
if
ret
is
not
None
and
ret
!=
self
.
last_value
:
logger
.
info
(
"{} at epoch {} will change to {}"
.
format
(
logger
.
info
(
"{} at epoch {} will change to {
:.8f
}"
.
format
(
self
.
param
.
readable_name
,
self
.
epoch_num
+
1
,
ret
))
self
.
last_value
=
ret
return
ret
def
get_current_value
(
self
):
return
self
.
param
.
get_value
()
@
abstractmethod
def
_get_value_to_set
(
self
):
pass
...
...
@@ -166,3 +176,43 @@ class ScheduledHyperParamSetter(HyperParamSetter):
return
v
return
None
class
NonDecreasingStatMonitorParamSetter
(
HyperParamSetter
):
"""
Set hyperparameter by a func, if a specific stat wasn't
monotonically decreasing $a$ times out of the last $b$ epochs
"""
def
__init__
(
self
,
param
,
stat_name
,
value_func
,
last_k
=
5
,
min_non_decreasing
=
2
):
"""
Change param by `new_value = value_func(old_value)`,
if `stat_name` wasn't decreasing >=2 times in the lastest 5 times of
statistics update.
For example, if error wasn't decreasing, anneal the learning rate:
NonDecreasingStatMonitorParamSetter('learning_rate', 'val-error', lambda x: x * 0.2)
"""
super
(
NonDecreasingStatMonitorParamSetter
,
self
)
.
__init__
(
param
)
self
.
stat_name
=
stat_name
self
.
value_func
=
value_func
self
.
last_k
=
last_k
self
.
min_non_decreasing
=
min_non_decreasing
self
.
last_changed_epoch
=
0
def
_get_value_to_set
(
self
):
holder
=
self
.
trainer
.
stat_holder
hist
=
holder
.
get_stat_history
(
self
.
stat_name
)
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
return
None
hist
=
hist
[
-
self
.
last_k
-
1
:]
# len==last_k+1
cnt
=
0
for
k
in
range
(
self
.
last_k
):
if
hist
[
k
]
<=
hist
[
k
+
1
]:
cnt
+=
1
if
cnt
>=
self
.
min_non_decreasing
\
and
hist
[
-
1
]
>=
hist
[
0
]:
return
self
.
value_func
(
self
.
get_current_value
())
return
None
tensorpack/callbacks/summary.py
View file @
509c2c90
...
...
@@ -57,6 +57,15 @@ class StatHolder(object):
"""
return
self
.
stat_now
[
key
]
def
get_stat_history
(
self
,
key
):
ret
=
[]
for
h
in
self
.
stat_history
:
v
=
h
.
get
(
key
,
None
)
if
v
is
not
None
:
ret
.
append
(
v
)
v
=
self
.
stat_now
.
get
(
key
,
None
)
if
v
is
not
None
:
ret
.
append
(
v
)
return
ret
def
finalize
(
self
):
"""
Called after finishing adding stats. Will print and write stats to disk.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment