Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ccf4a5a0
Commit
ccf4a5a0
authored
Feb 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
rename extra_fetch to before_run (#147)
parent
39d08d47
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
31 additions
and
44 deletions
+31
-44
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+18
-14
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+2
-4
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+7
-8
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+0
-4
tensorpack/train/base.py
tensorpack/train/base.py
+3
-13
No files found.
tensorpack/callbacks/base.py
View file @
ccf4a5a0
...
...
@@ -69,18 +69,18 @@ class Callback(object):
def
_after_run
(
self
,
run_context
,
run_values
):
pass
def
extra_fetches
(
self
):
def
before_run
(
self
,
ctx
):
"""
Returns:
list: a list of elements to be fetched in every step and
passed to :meth:`trigger_step`. Elements can be
Operations/Tensors, or names of Operations/Tensors.
This function will be called only after the graph is finalized.
This function should be a pure function (i.e. no side-effect when called)
Same as ``tf.train.SessionRunHook.before_run``.
"""
fetches
=
self
.
_extra_fetches
()
fetches
=
self
.
_before_run
(
ctx
)
if
isinstance
(
fetches
,
tf
.
train
.
SessionRunArgs
):
return
fetches
if
fetches
is
None
:
return
None
# also support list of names
assert
isinstance
(
fetches
,
list
),
fetches
ret
=
[]
for
f
in
fetches
:
if
isinstance
(
f
,
(
tf
.
Tensor
,
tf
.
Operation
)):
...
...
@@ -88,10 +88,10 @@ class Callback(object):
else
:
# warn about speed
ret
.
append
(
get_op_or_tensor_by_name
(
f
))
return
ret
return
tf
.
train
.
SessionRunArgs
(
fetches
=
ret
)
def
_
extra_fetches
(
self
):
return
[]
def
_
before_run
(
self
,
ctx
):
return
None
def
trigger_epoch
(
self
):
"""
...
...
@@ -180,7 +180,11 @@ class ProxyCallback(Callback):
def
_after_train
(
self
):
self
.
cb
.
after_train
()
# TODO before/after_run
def
_before_run
(
self
,
ctx
):
self
.
cb
.
_before_run
(
ctx
)
def
_after_run
(
self
,
ctx
,
run_values
):
self
.
cb
.
_after_run
(
ctx
,
run_values
)
def
__str__
(
self
):
return
"Proxy-"
+
str
(
self
.
cb
)
...
...
tensorpack/callbacks/group.py
View file @
ccf4a5a0
...
...
@@ -18,9 +18,8 @@ class CallbackHook(tf.train.SessionRunHook):
def
__init__
(
self
,
cb
):
self
.
cb
=
cb
def
before_run
(
self
,
_
):
return
tf
.
train
.
SessionRunArgs
(
fetches
=
self
.
cb
.
extra_fetches
())
def
before_run
(
self
,
ctx
):
return
self
.
cb
.
before_run
(
ctx
)
def
after_run
(
self
,
ctx
,
vals
):
self
.
cb
.
after_run
(
ctx
,
vals
)
...
...
@@ -81,7 +80,6 @@ class Callbacks(Callback):
break
self
.
cbs
=
cbs
self
.
_extra_fetches_cache
=
None
def
_setup_graph
(
self
):
with
tf
.
name_scope
(
None
):
...
...
tensorpack/callbacks/steps.py
View file @
ccf4a5a0
...
...
@@ -38,10 +38,10 @@ class StepTensorPrinter(Callback):
def
_before_train
(
self
):
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
_
):
return
self
.
_fetches
def
_after_run
(
self
,
ctx
,
vals
):
def
_after_run
(
self
,
_
,
vals
):
args
=
vals
.
results
assert
len
(
args
)
==
len
(
self
.
_names
),
len
(
args
)
for
n
,
v
in
zip
(
self
.
_names
,
args
):
...
...
@@ -71,13 +71,13 @@ class MaintainStepCounter(Callback):
logger
.
info
(
"Start training with global_step={}"
.
format
(
gs_val
))
self
.
_last_updated
=
self
.
trainer
.
local_step
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
_
):
# increase global_step, when trainer.local_step changed
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
trainer
.
local_step
return
[
self
.
gs_incr_var
.
op
]
else
:
return
[]
return
None
class
ProgressBar
(
Callback
):
...
...
@@ -101,9 +101,8 @@ class ProgressBar(Callback):
if
len
(
self
.
_names
):
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
_
):
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
# local_step == number of steps that have finished in this epoch
self
.
_last_updated
=
self
.
trainer
.
local_step
if
self
.
trainer
.
local_step
==
0
:
...
...
@@ -111,9 +110,9 @@ class ProgressBar(Callback):
return
self
.
_fetches
else
:
return
[]
return
None
def
_after_run
(
self
,
ctx
,
run_values
):
def
_after_run
(
self
,
_
,
run_values
):
res
=
run_values
.
results
if
len
(
res
):
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
res
))
...
...
tensorpack/callbacks/summary.py
View file @
ccf4a5a0
...
...
@@ -28,5 +28,5 @@ class MovingAverageSummary(Callback):
ops
=
tf
.
get_collection
(
self
.
_collection
)
self
.
ema_op
=
tf
.
group
(
*
ops
,
name
=
'summary_moving_averages'
)
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
_
):
return
[
self
.
ema_op
]
tensorpack/callbacks/trigger.py
View file @
ccf4a5a0
...
...
@@ -64,10 +64,6 @@ class PeriodicCallback(ProxyCallback):
Args:
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
...
...
tensorpack/train/base.py
View file @
ccf4a5a0
...
...
@@ -40,8 +40,8 @@ class Trainer(object):
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
epoch_num (int): the
current epoch number
.
local_step (int): the
current step number (in an epoch)
.
epoch_num (int): the
number of epochs that have finished
.
local_step (int): the
number of steps that have finished in the current epoch
.
"""
def
__init__
(
self
,
config
):
...
...
@@ -65,16 +65,6 @@ class Trainer(object):
def
run_step
(
self
):
""" Abstract method. Run one iteration. """
def
get_extra_fetches
(
self
):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
# TODO remove this func
return
[]
def
trigger_epoch
(
self
):
"""
Called after each epoch.
...
...
@@ -162,7 +152,7 @@ class Trainer(object):
try
:
return
self
.
_starting_step
+
\
self
.
config
.
steps_per_epoch
*
(
self
.
epoch_num
-
1
)
+
\
self
.
local_step
+
1
self
.
local_step
+
1
# +1: the ongoing step
except
AttributeError
:
return
get_global_step_value
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment