Shashank Suhas / seminar-breakout / Commits

Commit 2d720b60
authored Mar 25, 2016 by Yuxin Wu

schedule hyper param setter

parent 0dbfe237
Changes: 2 changed files, 28 additions and 9 deletions (+28 -9)

examples/cifar10_resnet.py        +3   -5
tensorpack/callbacks/param.py     +25  -4
examples/cifar10_resnet.py (view file @ 2d720b60)

@@ -147,11 +147,7 @@ def get_config():
     sess_config = get_default_sess_config(0.9)
-    lr = tf.train.exponential_decay(
-        learning_rate=1e-1,
-        global_step=get_global_step_var(),
-        decay_steps=36000,
-        decay_rate=0.1, staircase=True, name='learning_rate')
+    lr = tf.Variable(0.1, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
     return TrainConfig(
...
@@ -161,6 +157,8 @@ def get_config():
             StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
+            ScheduledHyperParamSetter('learning_rate',
+                                      [(82, 0.01), (123, 0.001), (300, 0.0001)])
         ]),
         session_config=sess_config,
         model=Model(n=18),
...
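Taken together, the two hunks swap the step-based exponential decay for a plain learning-rate variable that the new ScheduledHyperParamSetter callback reassigns at fixed epochs. The plain-Python sketch below (not part of the commit; it assumes epochs are counted from 1) illustrates the effective schedule: the value is only reassigned at the exact epochs listed and holds its last value otherwise, matching the setter implemented in tensorpack/callbacks/param.py below.

# Illustration only, not part of this commit. The schedule and the 0.1
# initial value are copied from the diff above; no TF/tensorpack needed.
schedule = [(82, 0.01), (123, 0.001), (300, 0.0001)]

def effective_lr(epoch, initial=0.1):
    # Mimic the setter: the variable is only reassigned at the exact epochs
    # listed; between those epochs it keeps whatever value it last had.
    current = initial
    for e, v in sorted(schedule):
        if e <= epoch:
            current = v
    return current

assert effective_lr(81) == 0.1
assert effective_lr(82) == 0.01
assert effective_lr(200) == 0.001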
tensorpack/callbacks/param.py (view file @ 2d720b60)

@@ -5,10 +5,13 @@
 import tensorflow as tf
 from abc import abstractmethod, ABCMeta
+import operator

 from .base import Callback
 from ..utils import logger, get_op_var_name

-__all__ = ['HyperParamSetter', 'HumanHyperParamSetter']
+__all__ = ['HyperParamSetter', 'HumanHyperParamSetter',
+           'ScheduledHyperParamSetter']

 class HyperParamSetter(Callback):
     __metaclass__ = ABCMeta
...
@@ -35,9 +38,9 @@ class HyperParamSetter(Callback):
     def get_current_value(self):
         ret = self._get_current_value()
-        if ret != self.last_value:
+        if ret is not None and ret != self.last_value:
             logger.info("{} at epoch {} is changed to {}".format(
-                self.var_name, self.epoch_num, ret))
+                self.op_name, self.epoch_num, ret))
         self.last_value = ret
         return ret
...
@@ -47,6 +50,7 @@ class HyperParamSetter(Callback):
     def _trigger_epoch(self):
         v = self.get_current_value()
         if v is not None:
             self.assign_op.eval(feed_dict={self.val_holder: v})
+
 class HumanHyperParamSetter(HyperParamSetter):
...
@@ -64,3 +68,20 @@ class HumanHyperParamSetter(HyperParamSetter):
         lines = [s.strip().split(':') for s in lines]
         dic = {str(k): float(v) for k, v in lines}
         return dic[self.op_name]
+
+class ScheduledHyperParamSetter(HyperParamSetter):
+    def __init__(self, var_name, schedule):
+        """
+        schedule: [(epoch1, val1), (epoch2, val2), (epoch3, val3), ...]
+        """
+        self.schedule = sorted(schedule, key=operator.itemgetter(0))
+        super(ScheduledHyperParamSetter, self).__init__(var_name)
+
+    def _get_current_value(self):
+        for e, v in self.schedule:
+            if e == self.epoch_num:
+                return v
+        return None
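The new class also shows the extension point: a subclass only needs to implement _get_current_value(), returning either the value for the current epoch or None to leave the variable untouched. As a hedged sketch (not part of this commit; the class name and parameters are hypothetical), an epoch-based staircase decay similar to the removed tf.train.exponential_decay could be written against the same contract:

# Hypothetical example, not in this commit: an epoch-based staircase decay
# built on the HyperParamSetter contract shown above.
class EpochExponentialDecay(HyperParamSetter):
    def __init__(self, var_name, base, rate, every_k_epochs):
        self.base = base          # initial value, e.g. 0.1
        self.rate = rate          # multiplicative decay factor, e.g. 0.1
        self.k = every_k_epochs   # decay period, in epochs
        super(EpochExponentialDecay, self).__init__(var_name)

    def _get_current_value(self):
        # Return a value only on epochs where it actually changes, so that
        # _trigger_epoch() skips the assignment (and the log line) otherwise.
        if self.epoch_num % self.k == 0:
            return self.base * (self.rate ** (self.epoch_num // self.k))
        return None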