Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
16145cc8
Commit
16145cc8
authored
Jun 12, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
before/after_epoch callbacks and progressbar (fix #292)
parent
02f5f303
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
11 deletions
+47
-11
docs/tutorial/extend/callback.md
docs/tutorial/extend/callback.md
+2
-0
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+29
-5
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+8
-0
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+6
-6
tensorpack/train/base.py
tensorpack/train/base.py
+2
-0
No files found.
docs/tutorial/extend/callback.md
View file @
16145cc8
...
@@ -11,9 +11,11 @@ def main_loop():
...
@@ -11,9 +11,11 @@ def main_loop():
# start training:
# start training:
callbacks
.
before_train
()
callbacks
.
before_train
()
for
epoch
in
range
(
epoch_start
,
epoch_end
):
for
epoch
in
range
(
epoch_start
,
epoch_end
):
callbacks
.
before_epoch
()
for
step
in
range
(
steps_per_epoch
):
for
step
in
range
(
steps_per_epoch
):
run_step
()
# callbacks.{before,after}_run are hooked with session
run_step
()
# callbacks.{before,after}_run are hooked with session
callbacks
.
trigger_step
()
callbacks
.
trigger_step
()
callbacks
.
after_epoch
()
callbacks
.
trigger_epoch
()
callbacks
.
trigger_epoch
()
callbacks
.
after_train
()
callbacks
.
after_train
()
```
```
...
...
tensorpack/callbacks/base.py
View file @
16145cc8
...
@@ -59,8 +59,7 @@ class Callback(object):
...
@@ -59,8 +59,7 @@ class Callback(object):
def
_before_train
(
self
):
def
_before_train
(
self
):
"""
"""
Called right before the first iteration. The main difference to
Called right before the first iteration. The main difference to
`setup_graph` is that at this point the graph is finalized and a
`setup_graph` is that at this point the graph is finalized and a default session is initialized.
default session is initialized.
Override this method to, e.g. run some operations under the session.
Override this method to, e.g. run some operations under the session.
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
This is similar to ``tf.train.SessionRunHook.after_create_session()``, but different:
...
@@ -68,6 +67,28 @@ class Callback(object):
...
@@ -68,6 +67,28 @@ class Callback(object):
"""
"""
pass
pass
def
before_epoch
(
self
):
self
.
_before_epoch
()
def
_before_epoch
(
self
):
"""
Called right before each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** before each epoch.
"""
pass
def
after_epoch
(
self
):
self
.
_after_epoch
()
def
_after_epoch
(
self
):
"""
Called right after each epoch.
Usually you should use the :meth:`trigger` callback to run something between epochs.
Use this method only when something really needs to be run **immediately** after each epoch.
"""
pass
def
before_run
(
self
,
ctx
):
def
before_run
(
self
,
ctx
):
fetches
=
self
.
_before_run
(
ctx
)
fetches
=
self
.
_before_run
(
ctx
)
if
fetches
is
None
:
if
fetches
is
None
:
...
@@ -92,9 +113,6 @@ class Callback(object):
...
@@ -92,9 +113,6 @@ class Callback(object):
registers some extra op/tensors to run in the next call.
registers some extra op/tensors to run in the next call.
This method is the same as ``tf.train.SessionRunHook.before_run``.
This method is the same as ``tf.train.SessionRunHook.before_run``.
Refer to TensorFlow docs for more details.
Refer to TensorFlow docs for more details.
An extra feature is that you can also simply return a list of names,
instead of a ``tf.train.SessionRunArgs``.
"""
"""
return
None
return
None
...
@@ -213,6 +231,12 @@ class ProxyCallback(Callback):
...
@@ -213,6 +231,12 @@ class ProxyCallback(Callback):
def
_after_train
(
self
):
def
_after_train
(
self
):
self
.
cb
.
after_train
()
self
.
cb
.
after_train
()
def
_before_epoch
(
self
):
self
.
cb
.
before_epoch
()
def
_after_epoch
(
self
):
self
.
cb
.
after_epoch
()
def
_before_run
(
self
,
ctx
):
def
_before_run
(
self
,
ctx
):
self
.
cb
.
_before_run
(
ctx
)
self
.
cb
.
_before_run
(
ctx
)
...
...
tensorpack/callbacks/group.py
View file @
16145cc8
...
@@ -104,6 +104,14 @@ class Callbacks(Callback):
...
@@ -104,6 +104,14 @@ class Callbacks(Callback):
cb
.
trigger_epoch
()
cb
.
trigger_epoch
()
tm
.
log
()
tm
.
log
()
def
_before_epoch
(
self
):
for
cb
in
self
.
cbs
:
cb
.
before_epoch
()
def
_after_epoch
(
self
):
for
cb
in
self
.
cbs
:
cb
.
after_epoch
()
def
append
(
self
,
cb
):
def
append
(
self
,
cb
):
assert
isinstance
(
cb
,
Callback
)
assert
isinstance
(
cb
,
Callback
)
self
.
cbs
.
append
(
cb
)
self
.
cbs
.
append
(
cb
)
tensorpack/callbacks/steps.py
View file @
16145cc8
...
@@ -106,14 +106,16 @@ class ProgressBar(Callback):
...
@@ -106,14 +106,16 @@ class ProgressBar(Callback):
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_before_epoch
(
self
):
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
def
_after_epoch
(
self
):
self
.
_bar
.
close
()
def
_before_run
(
self
,
_
):
def
_before_run
(
self
,
_
):
# update progress bar when local step changed (one step is finished)
# update progress bar when local step changed (one step is finished)
if
self
.
local_step
!=
self
.
_last_updated
:
if
self
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
local_step
self
.
_last_updated
=
self
.
local_step
if
self
.
local_step
==
0
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
return
self
.
_fetches
return
self
.
_fetches
else
:
else
:
return
None
return
None
...
@@ -125,8 +127,6 @@ class ProgressBar(Callback):
...
@@ -125,8 +127,6 @@ class ProgressBar(Callback):
def
_trigger_step
(
self
):
def
_trigger_step
(
self
):
self
.
_bar
.
update
()
self
.
_bar
.
update
()
if
self
.
local_step
==
self
.
_total
-
1
:
self
.
_bar
.
close
()
def
_after_train
(
self
):
def
_after_train
(
self
):
if
self
.
_bar
:
# training may get killed before the first step
if
self
.
_bar
:
# training may get killed before the first step
...
...
tensorpack/train/base.py
View file @
16145cc8
...
@@ -174,11 +174,13 @@ class Trainer(object):
...
@@ -174,11 +174,13 @@ class Trainer(object):
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
self
.
config
.
starting_epoch
,
self
.
config
.
max_epoch
+
1
):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
start_time
=
time
.
time
()
start_time
=
time
.
time
()
self
.
_callbacks
.
before_epoch
()
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
hooked_sess
.
should_stop
():
if
self
.
hooked_sess
.
should_stop
():
return
return
self
.
run_step
()
# implemented by subclass
self
.
run_step
()
# implemented by subclass
self
.
_callbacks
.
trigger_step
()
self
.
_callbacks
.
trigger_step
()
self
.
_callbacks
.
after_epoch
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment