Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
39d08d47
Commit
39d08d47
authored
Feb 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split trigger_step and after_run (#147)
parent
f80843dc
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
25 deletions
+33
-25
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+13
-10
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+6
-3
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+11
-10
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+2
-2
tensorpack/train/base.py
tensorpack/train/base.py
+1
-0
No files found.
tensorpack/callbacks/base.py
View file @
39d08d47
...
@@ -54,18 +54,19 @@ class Callback(object):
...
@@ -54,18 +54,19 @@ class Callback(object):
def
_before_train
(
self
):
def
_before_train
(
self
):
pass
pass
def
trigger_step
(
self
,
*
args
):
def
trigger_step
(
self
):
"""
"""
Callback to be triggered after every step (every backpropagation).
Callback to be triggered after every run_step.
"""
self
.
_trigger_step
()
Args
:
def
_trigger_step
(
self
)
:
args: a list of values corresponding to :meth:`extra_fetches`.
pass
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
def
after_run
(
self
,
run_context
,
run_values
):
"""
self
.
_after_run
(
run_context
,
run_values
)
self
.
_trigger_step
(
*
args
)
def
_
trigger_step
(
self
,
*
arg
s
):
def
_
after_run
(
self
,
run_context
,
run_value
s
):
pass
pass
def
extra_fetches
(
self
):
def
extra_fetches
(
self
):
...
@@ -173,12 +174,14 @@ class ProxyCallback(Callback):
...
@@ -173,12 +174,14 @@ class ProxyCallback(Callback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
cb
.
trigger_epoch
()
self
.
cb
.
trigger_epoch
()
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
):
self
.
cb
.
trigger_step
(
*
args
)
self
.
cb
.
trigger_step
()
def
_after_train
(
self
):
def
_after_train
(
self
):
self
.
cb
.
after_train
()
self
.
cb
.
after_train
()
# TODO before/after_run
def
__str__
(
self
):
def
__str__
(
self
):
return
"Proxy-"
+
str
(
self
.
cb
)
return
"Proxy-"
+
str
(
self
.
cb
)
...
...
tensorpack/callbacks/group.py
View file @
39d08d47
...
@@ -22,9 +22,8 @@ class CallbackHook(tf.train.SessionRunHook):
...
@@ -22,9 +22,8 @@ class CallbackHook(tf.train.SessionRunHook):
return
tf
.
train
.
SessionRunArgs
(
return
tf
.
train
.
SessionRunArgs
(
fetches
=
self
.
cb
.
extra_fetches
())
fetches
=
self
.
cb
.
extra_fetches
())
def
after_run
(
self
,
_
,
vals
):
def
after_run
(
self
,
ctx
,
vals
):
res
=
vals
.
results
self
.
cb
.
after_run
(
ctx
,
vals
)
self
.
cb
.
trigger_step
(
*
res
)
class
CallbackTimeLogger
(
object
):
class
CallbackTimeLogger
(
object
):
...
@@ -104,6 +103,10 @@ class Callbacks(Callback):
...
@@ -104,6 +103,10 @@ class Callbacks(Callback):
def
get_hooks
(
self
):
def
get_hooks
(
self
):
return
[
CallbackHook
(
cb
)
for
cb
in
self
.
cbs
]
return
[
CallbackHook
(
cb
)
for
cb
in
self
.
cbs
]
def
trigger_step
(
self
):
for
cb
in
self
.
cbs
:
cb
.
trigger_step
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
tm
=
CallbackTimeLogger
()
...
...
tensorpack/callbacks/steps.py
View file @
39d08d47
...
@@ -41,7 +41,8 @@ class StepTensorPrinter(Callback):
...
@@ -41,7 +41,8 @@ class StepTensorPrinter(Callback):
def
_extra_fetches
(
self
):
def
_extra_fetches
(
self
):
return
self
.
_fetches
return
self
.
_fetches
def
_trigger_step
(
self
,
*
args
):
def
_after_run
(
self
,
ctx
,
vals
):
args
=
vals
.
results
assert
len
(
args
)
==
len
(
self
.
_names
),
len
(
args
)
assert
len
(
args
)
==
len
(
self
.
_names
),
len
(
args
)
for
n
,
v
in
zip
(
self
.
_names
,
args
):
for
n
,
v
in
zip
(
self
.
_names
,
args
):
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
...
@@ -107,17 +108,17 @@ class ProgressBar(Callback):
...
@@ -107,17 +108,17 @@ class ProgressBar(Callback):
if
self
.
trainer
.
local_step
==
0
:
if
self
.
trainer
.
local_step
==
0
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
else
:
self
.
_bar
.
update
()
# XXX TODO move this to trigger_step after rename
if
self
.
trainer
.
local_step
==
self
.
_total
-
1
:
self
.
_bar
.
close
()
return
self
.
_fetches
return
self
.
_fetches
else
:
else
:
return
[]
return
[]
def
_trigger_step
(
self
,
*
args
):
def
_after_run
(
self
,
ctx
,
run_values
):
if
len
(
args
):
res
=
run_values
.
results
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
args
))
if
len
(
res
):
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
res
))
def
_trigger_step
(
self
):
self
.
_bar
.
update
()
if
self
.
trainer
.
local_step
==
self
.
_total
-
1
:
self
.
_bar
.
close
()
tensorpack/callbacks/trigger.py
View file @
39d08d47
...
@@ -31,7 +31,7 @@ class PeriodicTrigger(ProxyCallback):
...
@@ -31,7 +31,7 @@ class PeriodicTrigger(ProxyCallback):
self
.
_step_k
=
every_k_steps
self
.
_step_k
=
every_k_steps
self
.
_epoch_k
=
every_k_epochs
self
.
_epoch_k
=
every_k_epochs
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
):
if
self
.
_step_k
is
None
:
if
self
.
_step_k
is
None
:
return
return
# trigger_step is triggered after run_step, so
# trigger_step is triggered after run_step, so
...
@@ -39,7 +39,7 @@ class PeriodicTrigger(ProxyCallback):
...
@@ -39,7 +39,7 @@ class PeriodicTrigger(ProxyCallback):
if
(
self
.
trainer
.
local_step
+
1
)
%
self
.
_step_k
==
0
:
if
(
self
.
trainer
.
local_step
+
1
)
%
self
.
_step_k
==
0
:
self
.
cb
.
trigger
()
self
.
cb
.
trigger
()
def
_trigger_epoch
(
self
,
*
args
):
def
_trigger_epoch
(
self
):
if
self
.
_epoch_k
is
None
:
if
self
.
_epoch_k
is
None
:
return
return
if
self
.
epoch_num
%
self
.
_epoch_k
==
0
:
if
self
.
epoch_num
%
self
.
_epoch_k
==
0
:
...
...
tensorpack/train/base.py
View file @
39d08d47
...
@@ -183,6 +183,7 @@ class Trainer(object):
...
@@ -183,6 +183,7 @@ class Trainer(object):
if
self
.
monitored_sess
.
should_stop
():
if
self
.
monitored_sess
.
should_stop
():
return
return
self
.
run_step
()
# implemented by subclass
self
.
run_step
()
# implemented by subclass
callbacks
.
trigger_step
()
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment