Commit b693fe25 authored by Yuxin Wu's avatar Yuxin Wu

allow period in mergeallsummaries

parent 54dc36ba
...@@ -33,12 +33,18 @@ class MovingAverageSummary(Callback): ...@@ -33,12 +33,18 @@ class MovingAverageSummary(Callback):
class MergeAllSummaries_RunAlone(Callback): class MergeAllSummaries_RunAlone(Callback):
def __init__(self, key): def __init__(self, period, key):
self._period = period
self._key = key self._key = key
def _setup_graph(self): def _setup_graph(self):
self.summary_op = tf.summary.merge_all(self._key) self.summary_op = tf.summary.merge_all(self._key)
def _trigger_step(self):
if self._period:
if (self.local_step + 1) % self._period == 0:
self._trigger()
def _trigger(self): def _trigger(self):
if self.summary_op: if self.summary_op:
summary = self.summary_op.eval() summary = self.summary_op.eval()
...@@ -46,7 +52,8 @@ class MergeAllSummaries_RunAlone(Callback): ...@@ -46,7 +52,8 @@ class MergeAllSummaries_RunAlone(Callback):
class MergeAllSummaries_RunWithOp(Callback): class MergeAllSummaries_RunWithOp(Callback):
def __init__(self, key): def __init__(self, period, key):
self._period = period
self._key = key self._key = key
def _setup_graph(self): def _setup_graph(self):
...@@ -57,8 +64,15 @@ class MergeAllSummaries_RunWithOp(Callback): ...@@ -57,8 +64,15 @@ class MergeAllSummaries_RunWithOp(Callback):
self._fetches = None self._fetches = None
self._total = self.trainer.config.steps_per_epoch self._total = self.trainer.config.steps_per_epoch
def _before_run(self, ctx): def _need_run(self):
if self.local_step == self._total - 1: if self.local_step == self._total - 1:
return True
if self._period > 0 and (self.local_step + 1) % self._period == 0:
return True
return False
def _before_run(self, ctx):
if self._need_run():
return self._fetches return self._fetches
return None return None
...@@ -69,11 +83,13 @@ class MergeAllSummaries_RunWithOp(Callback): ...@@ -69,11 +83,13 @@ class MergeAllSummaries_RunWithOp(Callback):
self.trainer.monitors.put_summary(summary) self.trainer.monitors.put_summary(summary)
def MergeAllSummaries(run_alone=False, key=tf.GraphKeys.SUMMARIES): def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
""" """
Evaluate all summaries by `tf.summary.merge_all`, and write to logs. Evaluate all summaries by `tf.summary.merge_all`, and write to logs.
Args: Args:
period (int): by default the callback summarizes once every epoch.
This option (if not set to 0) makes it additionally summarize every ``period`` steps.
run_alone (bool): whether to evaluate the summaries alone. run_alone (bool): whether to evaluate the summaries alone.
If True, summaries will be evaluated after each epoch alone. If True, summaries will be evaluated after each epoch alone.
If False, summaries will be evaluated together with other If False, summaries will be evaluated together with other
...@@ -85,7 +101,8 @@ def MergeAllSummaries(run_alone=False, key=tf.GraphKeys.SUMMARIES): ...@@ -85,7 +101,8 @@ def MergeAllSummaries(run_alone=False, key=tf.GraphKeys.SUMMARIES):
Returns: Returns:
a Callback. a Callback.
""" """
period = int(period)
if run_alone: if run_alone:
return MergeAllSummaries_RunAlone(key) return MergeAllSummaries_RunAlone(period, key)
else: else:
return MergeAllSummaries_RunWithOp(key) return MergeAllSummaries_RunWithOp(period, key)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment