Commit 745c70a4 authored by Yuxin Wu's avatar Yuxin Wu

split the implementation of MergeAllSummaries

parent 89bcdd10
......@@ -32,21 +32,21 @@ class MovingAverageSummary(Callback):
return [self.ema_op]
class MergeAllSummaries(Callback):
"""
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
"""
def __init__(self, run_alone=False, key=tf.GraphKeys.SUMMARIES):
"""
Args:
run_alone (bool): whether to evaluate the summaries alone.
If True, summaries will be evaluated after each epoch alone.
If False, summaries will be evaluated together with other
`sess.run` calls, in the last step of each epoch.
For :class:`SimpleTrainer`, it has to be False.
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
"""
self._run_alone = run_alone
class MergeAllSummaries_RunAlone(Callback):
def __init__(self, key):
self._key = key
def _setup_graph(self):
self.summary_op = tf.summary.merge_all(self._key)
def _trigger(self):
if self.summary_op:
summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary)
class MergeAllSummaries_RunWithOp(Callback):
def __init__(self, key):
self._key = key
def _setup_graph(self):
......@@ -58,8 +58,6 @@ class MergeAllSummaries(Callback):
self._total = self.trainer.config.steps_per_epoch
def _before_run(self, ctx):
if self._run_alone:
return None
if self.local_step == self._total - 1:
return self._fetches
return None
......@@ -70,10 +68,24 @@ class MergeAllSummaries(Callback):
return
self.trainer.monitors.put_summary(summary)
def _trigger(self):
summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self):
if self._run_alone:
self._trigger()
def MergeAllSummaries(run_alone=False, key=tf.GraphKeys.SUMMARIES):
"""
Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
Args:
run_alone (bool): whether to evaluate the summaries alone.
If True, summaries will be evaluated after each epoch alone.
If False, summaries will be evaluated together with other
`sess.run` calls, in the last step of each epoch.
For :class:`SimpleTrainer`, it needs to be False because summary may
depend on inputs.
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
Returns:
a Callback.
"""
if run_alone:
return MergeAllSummaries_RunAlone(key)
else:
return MergeAllSummaries_RunWithOp(key)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment