Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fded4f40
Commit
fded4f40
authored
Nov 05, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Check global_step in MinSaver (fix #966)
parent
5772d5fd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
23 deletions
+35
-23
examples/basics/mnist-convnet.py
examples/basics/mnist-convnet.py
+4
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+31
-21
No files found.
examples/basics/mnist-convnet.py
View file @
fded4f40
...
...
@@ -115,10 +115,12 @@ if __name__ == '__main__':
data
=
FeedInput
(
dataset_train
),
callbacks
=
[
ModelSaver
(),
# save the model after every epoch
MaxSaver
(
'validation_accuracy'
),
# save the model with highest accuracy (prefix 'validation_')
InferenceRunner
(
# run inference(for validation) after every epoch
dataset_test
,
# the DataFlow instance used for validation
ScalarStats
([
'cross_entropy_loss'
,
'accuracy'
])),
ScalarStats
(
# produce `val_accuracy` and `val_cross_entropy_loss`
[
'cross_entropy_loss'
,
'accuracy'
],
prefix
=
'val'
)),
# MaxSaver has to come after InferenceRunner
MaxSaver
(
'val_accuracy'
),
# save the model with highest accuracy
],
steps_per_epoch
=
steps_per_epoch
,
max_epoch
=
100
,
...
...
tensorpack/callbacks/saver.py
View file @
fded4f40
...
...
@@ -102,58 +102,68 @@ class MinSaver(Callback):
MinSaver('val-error')
Note:
It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
and appears earlier in the callback list.
The default for both :class:`ModelSaver` and :class:`MinSaver`
is ``checkpoint_dir=logger.get_logger_dir()``
1. It assumes that :class:`ModelSaver` is used with the same ``checkpoint_dir``
and appears earlier in the callback list.
The default for both :class:`ModelSaver` and :class:`MinSaver`
is ``checkpoint_dir=logger.get_logger_dir()``
2. Callbacks are executed in the order they are defined. Therefore you'd want to
use this callback after the callback (e.g. InferenceRunner) that produces the statistics.
"""
self
.
monitor_stat
=
monitor_stat
self
.
reverse
=
reverse
self
.
filename
=
filename
self
.
min
=
None
self
.
best
=
None
self
.
checkpoint_dir
=
checkpoint_dir
if
self
.
checkpoint_dir
is
None
:
self
.
checkpoint_dir
=
logger
.
get_logger_dir
()
def
_get_stat
(
self
):
try
:
v
=
self
.
trainer
.
monitors
.
get_
latest
(
self
.
monitor_stat
)
except
KeyError
:
v
=
None
v
=
self
.
trainer
.
monitors
.
get_
history
(
self
.
monitor_stat
)[
-
1
]
except
(
KeyError
,
IndexError
)
:
v
=
None
,
None
return
v
def
_need_save
(
self
):
v
=
self
.
_get_stat
()
if
not
v
:
return
False
return
v
>
self
.
min
if
self
.
reverse
else
v
<
self
.
min
def
_trigger
(
self
):
if
self
.
min
is
None
or
self
.
_need_save
():
self
.
min
=
self
.
_get_stat
()
if
self
.
min
:
self
.
_save
()
curr_step
,
curr_val
=
self
.
_get_stat
()
if
curr_step
is
None
:
return
if
self
.
best
is
None
or
(
curr_val
>
self
.
best
[
1
]
if
self
.
reverse
else
curr_val
<
self
.
best
[
1
]):
self
.
best
=
(
curr_step
,
curr_val
)
self
.
_save
()
def
_save
(
self
):
ckpt
=
tf
.
train
.
get_checkpoint_state
(
self
.
checkpoint_dir
)
if
ckpt
is
None
:
raise
RuntimeError
(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?"
)
"
[MinSaver]
Cannot find a checkpoint state. Do you forget to use ModelSaver?"
)
path
=
ckpt
.
model_checkpoint_path
extreme_name
=
'maximum'
if
self
.
reverse
else
'minimum'
if
not
path
.
endswith
(
str
(
self
.
best
[
0
])):
logger
.
warn
(
"[MinSaver] New {} '{}' found at global_step={}, but the latest checkpoint is {}."
.
format
(
extreme_name
,
self
.
monitor_stat
,
self
.
best
[
0
],
path
))
logger
.
warn
(
"MinSaver will do nothing this time. "
"The callbacks may have inconsistent frequency or wrong order."
)
return
newname
=
os
.
path
.
join
(
self
.
checkpoint_dir
,
self
.
filename
or
(
'max-'
+
self
.
monitor_stat
if
self
.
reverse
else
'min-'
+
self
.
monitor_stat
))
files_to_copy
=
tf
.
gfile
.
Glob
(
path
+
'*'
)
for
file_to_copy
in
files_to_copy
:
tf
.
gfile
.
Copy
(
file_to_copy
,
file_to_copy
.
replace
(
path
,
newname
),
overwrite
=
True
)
logger
.
info
(
"Model
with {} '{}'
saved."
.
format
(
'maximum'
if
self
.
reverse
else
'minimum'
,
self
.
monitor_stat
))
logger
.
info
(
"Model
at global_step={} with {} {}={:.5g}
saved."
.
format
(
self
.
best
[
0
],
extreme_name
,
self
.
monitor_stat
,
self
.
best
[
1
]
))
class
MaxSaver
(
MinSaver
):
"""
Separately save the model with maximum value of some statistics.
See docs of :class:`MinSaver` for details.
"""
def
__init__
(
self
,
monitor_stat
,
filename
=
None
,
checkpoint_dir
=
None
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment