Shashank Suhas / seminar-breakout / Commits

Commit 5ae5f3e5, authored Jun 05, 2017 by Yuxin Wu
some docs about callbacks

parent 0630a31c
Showing 5 changed files with 88 additions and 23 deletions
docs/tutorial/callback.md          +12  -8
docs/tutorial/extend/callback.md   +66  -1
examples/DeepQNetwork/DQN.py        +1  -1
tensorpack/predict/base.py          +3  -5
tensorpack/train/base.py            +6  -8
docs/tutorial/callback.md

@@ -4,7 +4,7 @@
 Apart from the actual training iterations that minimize the cost,
 you almost surely would like to do something else during training.
 Callbacks are such an interface to describe what to do besides the
-training iterations defined by the trainers.
+training iterations.
 There are several places where you might want to do something else:

@@ -14,9 +14,13 @@ There are several places where you might want to do something else:
 * Between epochs (e.g. save the model, run some validation)
 * After the training (e.g. send the model somewhere, send a message to your phone)
 
-By writing callbacks to implement these tasks, you can reuse the code as long as
-you are using tensorpack trainers. For example, these are the callbacks I used when training
-a ResNet:
+We found that people traditionally tend to write the training loop together with these extra features.
+This makes the loop lengthy, and the code for the same feature probably gets scattered.
+By writing callbacks to implement what you want to do at each place, tensorpack trainers
+will call them at the proper time.
+Therefore the code can be reused with a single line, as long as you are using tensorpack trainers.
+For example, these are the callbacks I used when training a ResNet:
 
 ```python
 TrainConfig(
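The diff cuts off right at the `TrainConfig(` example, and the actual ResNet callback list is elided above. As a hedged sketch only: `ModelSaver`, `InferenceRunner`, `ScalarStats`, `ClassificationError`, and `ScheduledHyperParamSetter` are built-in tensorpack callbacks, while the dataflow, model, and schedule values below are illustrative, not taken from the tutorial:

```python
# Illustrative sketch of a callback list, not the actual ResNet config:
TrainConfig(
    dataflow=my_dataflow,   # assumed: your input pipeline
    model=my_model,         # assumed: your ModelDesc
    callbacks=[
        ModelSaver(),                     # between epochs: save the model
        InferenceRunner(dataset_val,      # between epochs: run validation
                        [ScalarStats('cost'), ClassificationError()]),
        ScheduledHyperParamSetter(        # before epochs: set hyperparameters
            'learning_rate', [(1, 0.1), (82, 0.01), (123, 0.001)]),
    ],
    max_epoch=200,
)
```

Each feature is one entry in this list, which is the single-line reuse the paragraph above describes.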
docs/tutorial/extend/callback.md

## Write a callback

The places where each callback gets called are demonstrated in this snippet:

```python
def main_loop():
    # create graph for the model
    callbacks.setup_graph()
    # create session, initialize session, finalize graph ...
    # start training:
    callbacks.before_train()
    for epoch in range(epoch_start, epoch_end):
        for step in range(steps_per_epoch):
            run_step()  # callbacks.{before,after}_run are hooked with session
            callbacks.trigger_step()
        callbacks.trigger_epoch()
    callbacks.after_train()
```

You can override any of the following methods to define a new callback:

* `_setup_graph(self)`

To separate "define" from "run", and to avoid the common mistake of creating ops inside
loops, all changes to the graph should be made in this method. No session has been created at this time.

TODO how to access the tensors already defined.

* `_before_train(self)`

Can be used to run some manual initialization of variables, or to start some services for the whole training.

* `_trigger_step(self)`

Do something (including running ops) after each step has finished.
Be careful to only do light work here, because it could affect training speed.

* `_before_run(self, ctx)`, `_after_run(self, ctx, values)`

These two are the equivalent of [tf.train.SessionRunHook](https://www.tensorflow.org/api_docs/python/tf/train/SessionRunHook).
Please refer to the TensorFlow documentation for the detailed API.
They are used to run extra ops / evaluate extra tensors / feed extra values __along with__ the actual training iteration.
Note the difference between running __along with__ an iteration and running after an iteration.
When you write

```python
def _before_run(self, _):
    return tf.train.SessionRunArgs(fetches=my_op)
```

the training loop would become `sess.run([training_op, my_op])`.
This is different from `sess.run(training_op); sess.run(my_op);`,
which is what you would get if you run the op in `_trigger_step`.

* `_trigger_epoch(self)`

Do something after each epoch has finished. Will call `self.trigger()` by default.

* `_trigger(self)`

By default will get called by `_trigger_epoch`,
but you can customize the scheduling of this callback with
`PeriodicTrigger`, to let this method run every k steps or every k epochs.

* `_after_train(self)`

Do some finalization work.
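To make the method list above concrete, here is a minimal custom callback sketched against this interface. It is a hypothetical example, not part of the commit: it assumes the `Callback` base class from `tensorpack.callbacks` and a scalar tensor named `total_cost:0` in the graph.

```python
import tensorflow as tf
from tensorpack.callbacks import Callback

class MeanCostPrinter(Callback):
    """Hypothetical callback: fetch the cost along with every step,
    then print its mean between epochs."""

    def _setup_graph(self):
        # all graph access happens here, before any session exists
        # (assumption: the model defines a tensor named 'total_cost:0')
        self._cost = tf.get_default_graph().get_tensor_by_name('total_cost:0')
        self._values = []

    def _before_run(self, ctx):
        # fetched __along with__ the training op, in the same sess.run
        return tf.train.SessionRunArgs(fetches=self._cost)

    def _after_run(self, ctx, run_values):
        self._values.append(run_values.results)

    def _trigger_epoch(self):
        # only light work between epochs
        if self._values:
            print('mean cost: {:.4f}'.format(sum(self._values) / len(self._values)))
            self._values = []
```

Registering it is then the one-line change described in the tutorial: add `MeanCostPrinter()` to the `callbacks=` list of `TrainConfig`.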
examples/DeepQNetwork/DQN.py

@@ -114,7 +114,7 @@ def get_config():
         callbacks=[
             ModelSaver(),
             PeriodicTrigger(
-                RunOp(DQNModel.update_target_param),
+                RunOp(DQNModel.update_target_param, verbose=True),
                 every_k_steps=10000 // UPDATE_FREQ),   # update target network every 10k steps
             expreplay,
             ScheduledHyperParamSetter('learning_rate',
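This hunk also demonstrates the `PeriodicTrigger` pattern from the callback tutorial above: `RunOp` is a callback whose trigger runs an op, and `PeriodicTrigger` rate-limits any triggerable callback. A hedged sketch of the same pattern applied to validation (assuming a `dataset_val` dataflow):

```python
# Sketch: run validation every 3 epochs instead of every epoch.
# PeriodicTrigger also accepts every_k_steps, as in the DQN diff above.
PeriodicTrigger(
    InferenceRunner(dataset_val, [ScalarStats('cost')]),
    every_k_epochs=3)
```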
tensorpack/predict/base.py

@@ -106,7 +106,7 @@ class OnlinePredictor(PredictorBase):
             output_tensors (list): list of names.
             return_input (bool): same as :attr:`PredictorBase.return_input`.
             sess (tf.Session): the session this predictor runs in. If None,
-                will use the default session.
+                will use the default session at the first call.
         """
         self.return_input = return_input
         self.input_tensors = input_tensors

@@ -118,10 +118,8 @@ class OnlinePredictor(PredictorBase):
             "{} != {}".format(len(dp), len(self.input_tensors))
         feed = dict(zip(self.input_tensors, dp))
         if self.sess is None:
-            sess = tf.get_default_session()
-        else:
-            sess = self.sess
-        output = sess.run(self.output_tensors, feed_dict=feed)
+            self.sess = tf.get_default_session()
+        output = self.sess.run(self.output_tensors, feed_dict=feed)
         return output
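The behavioral change in this hunk: instead of looking up the default session on every call, the predictor now resolves it once, at the first call, and caches it on `self.sess`. A hedged usage sketch of what that implies (constructor arguments and the datapoint name are illustrative):

```python
# Sketch: the first call must happen while the intended session is the
# default one, since that session is cached and reused afterwards.
with sess.as_default():
    predictor = OnlinePredictor(input_tensors, output_tensors)
    out = predictor(image)   # self.sess is captured here
out = predictor(image)       # later calls reuse the cached session
```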
tensorpack/train/base.py

@@ -62,13 +62,7 @@ class Trainer(object):
         self.local_step = -1
         self._callbacks = []
-        self.register_callback(MaintainStepCounter())
-        for cb in config.callbacks:
-            self.register_callback(cb)
         self.monitors = []
-        for m in config.monitors:
-            self.register_monitor(m)
 
     def register_callback(self, cb):
         """

@@ -91,7 +85,7 @@ class Trainer(object):
         assert not isinstance(self.monitors, Monitors), \
             "Cannot register more monitors after trainer was setup!"
         if not self.is_chief and mon.chief_only:
-            logger.warn("Callback {} is chief-only, skipped.".format(str(mon)))
+            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
         else:
             self.monitors.append(mon)
             self.register_callback(mon)

@@ -115,10 +109,14 @@ class Trainer(object):
         """
         self._setup()   # subclass will setup the graph
 
+        self.register_callback(MaintainStepCounter())
+        for cb in self.config.callbacks:
+            self.register_callback(cb)
+        for m in self.config.monitors:
+            self.register_monitor(m)
+
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)
 
         # TODO cache per graph, avoid describing all towers
         describe_model()
 
         # some final operations that might modify the graph