Commit 745c70a4 authored by Yuxin Wu's avatar Yuxin Wu

split the implementation of MergeAllSummaries

parent 89bcdd10
...@@ -32,21 +32,21 @@ class MovingAverageSummary(Callback): ...@@ -32,21 +32,21 @@ class MovingAverageSummary(Callback):
return [self.ema_op] return [self.ema_op]
class MergeAllSummaries(Callback): class MergeAllSummaries_RunAlone(Callback):
""" def __init__(self, key):
Evaluate all summaries by `tf.summary.merge_all`, and write to logs. self._key = key
"""
def __init__(self, run_alone=False, key=tf.GraphKeys.SUMMARIES): def _setup_graph(self):
""" self.summary_op = tf.summary.merge_all(self._key)
Args:
run_alone (bool): whether to evaluate the summaries alone. def _trigger(self):
If True, summaries will be evaluated after each epoch alone. if self.summary_op:
If False, summaries will be evaluated together with other summary = self.summary_op.eval()
`sess.run` calls, in the last step of each epoch. self.trainer.monitors.put_summary(summary)
For :class:`SimpleTrainer`, it has to be False.
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
""" class MergeAllSummaries_RunWithOp(Callback):
self._run_alone = run_alone def __init__(self, key):
self._key = key self._key = key
def _setup_graph(self): def _setup_graph(self):
...@@ -58,8 +58,6 @@ class MergeAllSummaries(Callback): ...@@ -58,8 +58,6 @@ class MergeAllSummaries(Callback):
self._total = self.trainer.config.steps_per_epoch self._total = self.trainer.config.steps_per_epoch
def _before_run(self, ctx): def _before_run(self, ctx):
if self._run_alone:
return None
if self.local_step == self._total - 1: if self.local_step == self._total - 1:
return self._fetches return self._fetches
return None return None
...@@ -70,10 +68,24 @@ class MergeAllSummaries(Callback): ...@@ -70,10 +68,24 @@ class MergeAllSummaries(Callback):
return return
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def _trigger(self):
summary = self.summary_op.eval()
self.trainer.monitors.put_summary(summary)
def _trigger_epoch(self): def MergeAllSummaries(run_alone=False, key=tf.GraphKeys.SUMMARIES):
if self._run_alone: """
self._trigger() Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
Args:
run_alone (bool): whether to evaluate the summaries alone.
If True, summaries will be evaluated after each epoch alone.
If False, summaries will be evaluated together with other
`sess.run` calls, in the last step of each epoch.
For :class:`SimpleTrainer`, it needs to be False because summary may
depend on inputs.
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
Returns:
a Callback.
"""
if run_alone:
return MergeAllSummaries_RunAlone(key)
else:
return MergeAllSummaries_RunWithOp(key)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment