Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
16145cc8
Commit
16145cc8
authored
Jun 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
before/after_epoch callbacks and progressbar (fix #292)
parent
02f5f303
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
11 deletions
+47
-11
docs/tutorial/extend/callback.md
docs/tutorial/extend/callback.md
+2
-0
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+29
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+8
-0
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+6
-6
tensorpack/train/base.py
tensorpack/train/base.py
+2
-0
No files found.
docs/tutorial/extend/callback.md
View file @
16145cc8
...
...
@@ -11,9 +11,11 @@ def main_loop():
# start training:
callbacks
.
before_train
()
for
epoch
in
range
(
epoch_start
,
epoch_end
):
callbacks
.
before_epoch
()
for
step
in
range
(
steps_per_epoch
):
run_step
()
# callbacks.{before,after}_run are hooked with session
callbacks
.
trigger_step
()
callbacks
.
after_epoch
()
callbacks
.
trigger_epoch
()
callbacks
.
after_train
()
```
...
...
tensorpack/callbacks/base.py
View file @
16145cc8
...
...
@@ -59,8 +59,7 @@ class Callback(object):
def
_before_train
(
self
):
"""
Called right before the first iteration. The main difference to
`setup_graph` is that at this point the graph is finalized and a
default session is initialized.
`setup_graph` is that at this point the graph is finalized and a default session is initialized.
Override this method to, e.g. run some operations under the session.
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
...
...
@@ -68,6 +67,28 @@ class Callback(object):
"""
pass
def
before_epoch
(
self
):
self
.
_before_epoch
()
def
_before_epoch
(
self
):
"""
Called right before each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** before each epoch.
"""
pass
def
after_epoch
(
self
):
self
.
_after_epoch
()
def
_after_epoch
(
self
):
"""
Called right after each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** after each epoch.
"""
pass
def
before_run
(
self
,
ctx
):
fetches
=
self
.
_before_run
(
ctx
)
if
fetches
is
None
:
...
...
@@ -92,9 +113,6 @@ class Callback(object):
registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
"""
return
None
...
...
@@ -213,6 +231,12 @@ class ProxyCallback(Callback):
def
_after_train
(
self
):
self
.
cb
.
after_train
()
def
_before_epoch
(
self
):
self
.
cb
.
before_epoch
()
def
_after_epoch
(
self
):
self
.
cb
.
after_epoch
()
def
_before_run
(
self
,
ctx
):
self
.
cb
.
_before_run
(
ctx
)
...
...
tensorpack/callbacks/group.py
View file @
16145cc8
...
...
@@ -104,6 +104,14 @@ class Callbacks(Callback):
cb
.
trigger_epoch
()
tm
.
log
()
def
_before_epoch
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_epoch
()
def
_after_epoch
(
self
):
for
cb
in
self
.
cbs
:
cb
.
after_epoch
()
def
append
(
self
,
cb
):
assert
isinstance
(
cb
,
Callback
)
self
.
cbs
.
append
(
cb
)
tensorpack/callbacks/steps.py
View file @
16145cc8
...
...
@@ -106,14 +106,16 @@ class ProgressBar(Callback):
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_before_epoch
(
self
):
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
def
_after_epoch
(
self
):
self
.
_bar
.
close
()
def
_before_run
(
self
,
_
):
# update progress bar when local step changed (one step is finished)
if
self
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
local_step
if
self
.
local_step
==
0
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
return
self
.
_fetches
else
:
return
None
...
...
@@ -125,8 +127,6 @@ class ProgressBar(Callback):
def
_trigger_step
(
self
):
self
.
_bar
.
update
()
if
self
.
local_step
==
self
.
_total
-
1
:
self
.
_bar
.
close
()
def
_after_train
(
self
):
if
self
.
_bar
:
# training may get killed before the first step
...
...
tensorpack/train/base.py
View file @
16145cc8
...
...
@@ -174,11 +174,13 @@ class Trainer(object):
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
start_time
=
time
.
time
()
self
.
_callbacks
.
before_epoch
()
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
hooked_sess
.
should_stop
():
return
self
.
run_step
()
# implemented by subclass
self
.
_callbacks
.
trigger_step
()
self
.
_callbacks
.
after_epoch
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment