Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a71ff4d7
Commit
a71ff4d7
authored
Feb 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
move summary_op from trainer to callbacks. fix #125
parent
658529d5
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
21 deletions
+49
-21
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+45
-1
tensorpack/train/base.py
tensorpack/train/base.py
+1
-5
tensorpack/train/config.py
tensorpack/train/config.py
+3
-2
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+0
-7
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+0
-6
No files found.
tensorpack/callbacks/summary.py
View file @
a71ff4d7
...
...
@@ -8,7 +8,7 @@ import tensorflow as tf
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
.base
import
Callback
__all__
=
[
'MovingAverageSummary'
]
__all__
=
[
'MovingAverageSummary'
,
'MergeAllSummaries'
]
class
MovingAverageSummary
(
Callback
):
...
...
@@ -30,3 +30,47 @@ class MovingAverageSummary(Callback):
def
_before_run
(
self
,
_
):
return
[
self
.
ema_op
]
class MergeAllSummaries(Callback):
    """
    Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
    """
    def __init__(self, run_alone=False, key=tf.GraphKeys.SUMMARIES):
        """
        Args:
            run_alone (bool): whether to eval the summaries alone.
                If True, summaries will be evaluated after each epoch alone.
                If False, summaries will be evaluated together with other
                `sess.run` calls, in the last step of each epoch.
                For :class:`SimpleTrainer`, it has to be False.
            key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
        """
        self._run_alone = run_alone
        self._key = key

    def _setup_graph(self):
        # Merge every summary tensor in the chosen collection into a single op.
        # NOTE: tf.summary.merge_all returns None when the collection is empty,
        # hence the guard below.
        self.summary_op = tf.summary.merge_all(self._key)
        if self.summary_op is not None:
            self._fetches = tf.train.SessionRunArgs(self.summary_op)
        else:
            self._fetches = None
        # Cache steps_per_epoch so _before_run can detect the last step of an epoch.
        self._total = self.trainer.config.steps_per_epoch

    def _before_run(self, ctx):
        # In run_alone mode, nothing is piggy-backed on the training sess.run.
        if self._run_alone:
            return None
        # Otherwise, fetch the merged summaries together with the last
        # training step of each epoch (local_step is presumably 0-based,
        # so the last step is _total - 1 — matches the trainer's counter).
        if self.trainer.local_step == self._total - 1:
            return self._fetches
        return None

    def _after_run(self, _, run_values):
        # run_values.results is None on every step where _before_run returned
        # no fetches (and when the summary collection was empty).
        summary = run_values.results
        if summary is None:
            return
        self.trainer.add_summary(summary)

    def _trigger_epoch(self):
        if self._run_alone:
            # Standalone evaluation after each epoch; .eval() uses the default
            # session, so no feed is supplied (hence unusable with SimpleTrainer,
            # whose graph needs a feed_dict — see __init__ docstring).
            summary = self.summary_op.eval()
            self.trainer.add_summary(summary)
tensorpack/train/base.py
View file @
a71ff4d7
...
...
@@ -38,7 +38,6 @@ class Trainer(object):
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch.
...
...
@@ -75,7 +74,6 @@ class Trainer(object):
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
@
abstractmethod
def
_trigger_epoch
(
self
):
pass
...
...
@@ -121,7 +119,6 @@ class Trainer(object):
# some final operations that might modify the graph
logger
.
info
(
"Setup summaries ..."
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
self
.
summary_op
=
tf
.
summary
.
merge_all
()
# XXX not good
# create an empty StatHolder
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
...
...
@@ -178,8 +175,7 @@ class Trainer(object):
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
# trigger epoch outside the timing region.
self
.
trigger_epoch
()
self
.
trigger_epoch
()
# trigger epoch outside the timing region.
except
StopTraining
:
logger
.
info
(
"Training was stopped."
)
except
KeyboardInterrupt
:
...
...
tensorpack/train/config.py
View file @
a71ff4d7
...
...
@@ -6,7 +6,7 @@ import tensorflow as tf
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
StatPrinter
,
ProgressBar
,
StatPrinter
,
ProgressBar
,
MergeAllSummaries
,
MaintainStepCounter
)
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
...
...
@@ -41,7 +41,7 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), StatPrinter()]``. The list of
``[MovingAverageSummary(), ProgressBar(),
MergeAllSummaries(),
StatPrinter()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
Note that ``StatPrinter`` should be the last one to be able to print
stats generated by other callbacks.
...
...
@@ -86,6 +86,7 @@ class TrainConfig(object):
extra_callbacks
=
[
MovingAverageSummary
(),
ProgressBar
(),
MergeAllSummaries
(),
StatPrinter
()]
self
.
callbacks
=
[
MaintainStepCounter
()]
+
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
...
...
tensorpack/train/feedfree.py
View file @
a71ff4d7
...
...
@@ -20,13 +20,6 @@ class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`.
"""
def
_trigger_epoch
(
self
):
# run summary_op every epoch
# TODO FIXME summary_op will take a data! This is not good for TensorInput.
if
self
.
summary_op
is
not
None
:
summary_str
=
self
.
summary_op
.
eval
()
self
.
add_summary
(
summary_str
)
def
build_train_tower
(
self
):
"""
Get input tensors from `self.input_method` and build the graph.
...
...
tensorpack/train/trainer.py
View file @
a71ff4d7
...
...
@@ -101,12 +101,6 @@ class SimpleTrainer(Trainer):
grads
=
opt
.
compute_gradients
(
cost_var
)
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
def
_trigger_epoch
(
self
):
if
self
.
summary_op
is
not
None
:
feed
=
self
.
_input_method
.
last_feed
()
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
feed
)
self
.
add_summary
(
summary_str
)
def
get_predict_func
(
self
,
input_names
,
output_names
):
return
self
.
_predictor_factory
.
get_predictor
(
input_names
,
output_names
,
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment