Commit a71ff4d7 authored by Yuxin Wu

move summary_op from trainer to callbacks. fix #125

parent 658529d5
@@ -8,7 +8,7 @@ import tensorflow as tf
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .base import Callback

-__all__ = ['MovingAverageSummary']
+__all__ = ['MovingAverageSummary', 'MergeAllSummaries']


 class MovingAverageSummary(Callback):
@@ -30,3 +30,47 @@ class MovingAverageSummary(Callback):

     def _before_run(self, _):
         return [self.ema_op]
+
+
+class MergeAllSummaries(Callback):
+    """
+    Evaluate all summaries by `tf.summary.merge_all`, and write them to logs.
+    """
+    def __init__(self, run_alone=False, key=tf.GraphKeys.SUMMARIES):
+        """
+        Args:
+            run_alone (bool): whether to evaluate the summaries alone.
+                If True, summaries will be evaluated alone after each epoch.
+                If False, summaries will be evaluated together with other
+                `sess.run` calls, in the last step of each epoch.
+                For :class:`SimpleTrainer`, it has to be False.
+            key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
+        """
+        self._run_alone = run_alone
+        self._key = key
+
+    def _setup_graph(self):
+        self.summary_op = tf.summary.merge_all(self._key)
+        if self.summary_op is not None:
+            self._fetches = tf.train.SessionRunArgs(self.summary_op)
+        else:
+            self._fetches = None
+        self._total = self.trainer.config.steps_per_epoch
+
+    def _before_run(self, ctx):
+        if self._run_alone:
+            return None
+        if self.trainer.local_step == self._total - 1:
+            return self._fetches
+        return None
+
+    def _after_run(self, _, run_values):
+        summary = run_values.results
+        if summary is None:
+            return
+        self.trainer.add_summary(summary)
+
+    def _trigger_epoch(self):
+        if self._run_alone:
+            summary = self.summary_op.eval()
+            self.trainer.add_summary(summary)
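
For reference, a minimal sketch of the two modes of the new callback (not part of the diff; it uses only names this commit itself introduces and exports):

```python
from tensorpack.callbacks import MergeAllSummaries

# Default mode: piggyback on the last training step of each epoch.
# The merged summary op is fetched in the same sess.run as the training
# op, so it reuses that step's data (and feed_dict, for SimpleTrainer).
summaries = MergeAllSummaries()

# Standalone mode: evaluate the merged summary op by itself after each
# epoch. This issues an extra sess.run, which may consume an extra batch
# from a feedfree input, and cannot work with SimpleTrainer (no feed).
summaries_alone = MergeAllSummaries(run_alone=True)
```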
@@ -38,7 +38,6 @@ class Trainer(object):
         stat_holder (StatHolder)
         summary_writer (tf.summary.FileWriter)
-        summary_op (tf.Operation): an Op which outputs all summaries.
         epoch_num (int): the number of epochs that have finished.
         local_step (int): the number of steps that have finished in the current epoch.
@@ -75,7 +74,6 @@ class Trainer(object):
         self.config.callbacks.trigger_epoch()
         self.summary_writer.flush()

-    @abstractmethod
     def _trigger_epoch(self):
         pass
@@ -121,7 +119,6 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Setup summaries ...")
         self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
-        self.summary_op = tf.summary.merge_all()   # XXX not good
         # create an empty StatHolder
         self.stat_holder = StatHolder(logger.LOG_DIR)
@@ -178,8 +175,7 @@ class Trainer(object):
                 logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                     self.epoch_num, self.global_step, time.time() - start_time))
-                # trigger epoch outside the timing region.
-                self.trigger_epoch()
+                self.trigger_epoch()  # trigger epoch outside the timing region.
         except StopTraining:
             logger.info("Training was stopped.")
         except KeyboardInterrupt:
...
@@ -6,7 +6,7 @@ import tensorflow as tf
 from ..callbacks import (
     Callbacks, MovingAverageSummary,
-    StatPrinter, ProgressBar,
+    StatPrinter, ProgressBar, MergeAllSummaries,
     MaintainStepCounter)
 from ..dataflow.base import DataFlow
 from ..models import ModelDesc
@@ -41,7 +41,7 @@ class TrainConfig(object):
         callbacks (list): a list of :class:`Callback` to perform during training.
         extra_callbacks (list): the same as ``callbacks``. This argument
             is only used to provide the defaults. The defaults are
-            ``[MovingAverageSummary(), ProgressBar(), StatPrinter()]``. The list of
+            ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), StatPrinter()]``. The list of
             callbacks that will be used in the end is ``callbacks + extra_callbacks``.
             Note that ``StatPrinter`` should be the last one, to be able to print
             stats generated by other callbacks.
@@ -86,6 +86,7 @@ class TrainConfig(object):
             extra_callbacks = [
                 MovingAverageSummary(),
                 ProgressBar(),
+                MergeAllSummaries(),
                 StatPrinter()]
         self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
         assert_type(self.callbacks, list)
...
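
Because passing ``extra_callbacks`` replaces these defaults wholesale, a config that customizes them must re-add ``MergeAllSummaries()`` to keep summaries being written. A hedged sketch (``my_dataflow`` and ``my_model`` are hypothetical placeholders, and other ``TrainConfig`` arguments are omitted):

```python
from tensorpack import TrainConfig
from tensorpack.callbacks import (
    MovingAverageSummary, ProgressBar, MergeAllSummaries, StatPrinter)

config = TrainConfig(
    dataflow=my_dataflow,     # hypothetical DataFlow
    model=my_model,           # hypothetical ModelDesc
    callbacks=[],             # user callbacks run before the extras
    extra_callbacks=[         # overriding the defaults: re-add everything
        MovingAverageSummary(),
        ProgressBar(),
        MergeAllSummaries(),
        StatPrinter()],       # keep StatPrinter last so it sees all stats
)
```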
@@ -20,13 +20,6 @@ class FeedfreeTrainerBase(Trainer):
     """ A base trainer which runs iteration without feed_dict (therefore faster)
         Expect ``self.data`` to be a :class:`FeedfreeInput`.
     """
-
-    def _trigger_epoch(self):
-        # run summary_op every epoch
-        # TODO FIXME summary_op will take a data! This is not good for TensorInput.
-        if self.summary_op is not None:
-            summary_str = self.summary_op.eval()
-            self.add_summary(summary_str)

     def build_train_tower(self):
         """
         Get input tensors from `self.input_method` and build the graph.
...
@@ -101,12 +101,6 @@ class SimpleTrainer(Trainer):
         grads = opt.compute_gradients(cost_var)
         self.train_op = opt.apply_gradients(grads, name='min_op')

-    def _trigger_epoch(self):
-        if self.summary_op is not None:
-            feed = self._input_method.last_feed()
-            summary_str = self.summary_op.eval(feed_dict=feed)
-            self.add_summary(summary_str)
-
     def get_predict_func(self, input_names, output_names):
         return self._predictor_factory.get_predictor(input_names, output_names, 0)
...
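
The two removed ``_trigger_epoch`` bodies map onto the new callback roughly as follows (a sketch, not taken from the commit):

```python
from tensorpack.callbacks import MergeAllSummaries

# FeedfreeTrainerBase previously ran summary_op.eval() once per epoch,
# which pulled an extra batch from the input pipeline (the "TODO FIXME"
# removed above). The default mode instead fetches the summaries along
# with the last training step of each epoch:
summaries = MergeAllSummaries(run_alone=False)

# SimpleTrainer previously evaluated summary_op with the last feed_dict.
# run_alone=True would evaluate the summaries without any feed, so
# SimpleTrainer must also keep the default run_alone=False.
```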