Commit 63bf3d12 authored by Yuxin Wu's avatar Yuxin Wu

stat_holder now maintain the steps itself, to work better with the step...

stat_holder now maintain the steps itself, to work better with the step callbacks which produce stats.
parent 7f6c83d8
...@@ -120,9 +120,7 @@ class Callback(object): ...@@ -120,9 +120,7 @@ class Callback(object):
@property @property
def global_step(self): def global_step(self):
return self._starting_step + \ return self.trainer.global_step
self._steps_per_epoch * (self.epoch_num - 1) + \
self.local_step
def __str__(self): def __str__(self):
return type(self).__name__ return type(self).__name__
......
...@@ -69,11 +69,6 @@ class Callbacks(Callback): ...@@ -69,11 +69,6 @@ class Callbacks(Callback):
logger.warn("StatPrinter should appear as the last element of callbacks! " logger.warn("StatPrinter should appear as the last element of callbacks! "
"This is now fixed automatically, but may not work in the future.") "This is now fixed automatically, but may not work in the future.")
break break
else:
raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
nr_printer = sum([int(isinstance(cb, StatPrinter)) for cb in cbs])
if nr_printer != 1:
raise ValueError("Callbacks must contain one StatPrinter!")
self.cbs = cbs self.cbs = cbs
self._extra_fetches_cache = None self._extra_fetches_cache = None
......
...@@ -36,10 +36,18 @@ class StatHolder(object): ...@@ -36,10 +36,18 @@ class StatHolder(object):
else: else:
self.stat_history = [] self.stat_history = []
def add_stat(self, k, v): # global step of the current list of stat
self._current_gs = -1
def add_stat(self, k, v, global_step, epoch_num):
""" """
Add a stat. Add a stat.
""" """
if global_step != self._current_gs:
self._push()
self._current_gs = global_step
self.stat_now['epoch_num'] = epoch_num
self.stat_now['global_step'] = global_step
self.stat_now[k] = float(v) self.stat_now[k] = float(v)
def set_print_tag(self, print_tag): def set_print_tag(self, print_tag):
...@@ -85,13 +93,18 @@ class StatHolder(object): ...@@ -85,13 +93,18 @@ class StatHolder(object):
def finalize(self): def finalize(self):
""" """
Called after finishing adding stats for this epoch. Print and write stats to disk.
Will print and write stats to disk. This method is idempotent.
""" """
self._print_stat() self._print_stat()
self.stat_history.append(self.stat_now) self._push()
self.stat_now = {}
self._write_stat() def _push(self):
""" Note that this method is idempotent"""
if len(self.stat_now):
self.stat_history.append(self.stat_now)
self.stat_now = {}
self._write_stat()
def _print_stat(self): def _print_stat(self):
for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)): for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
...@@ -128,15 +141,8 @@ class StatPrinter(Triggerable): ...@@ -128,15 +141,8 @@ class StatPrinter(Triggerable):
self._stat_holder.set_print_tag(self.print_tag) self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num']) self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
# just try to add this stat earlier so SendStat can use
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
def _trigger(self): def _trigger(self):
# by default, add this two stat
self._stat_holder.add_stat('global_step', self.global_step)
self._stat_holder.finalize() self._stat_holder.finalize()
# this is for the next group of stat
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Triggerable): class SendStat(Triggerable):
......
...@@ -109,7 +109,9 @@ class Trainer(object): ...@@ -109,7 +109,9 @@ class Trainer(object):
suffix = '-summary' # issue#6150 suffix = '-summary' # issue#6150
if val.tag.endswith(suffix): if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)] val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(val.tag, val.simple_value) self.stat_holder.add_stat(
val.tag, val.simple_value,
self.global_step, self.epoch_num)
self.summary_writer.add_summary(summary, get_global_step_value()) self.summary_writer.add_summary(summary, get_global_step_value())
def add_scalar_summary(self, name, val): def add_scalar_summary(self, name, val):
...@@ -151,12 +153,19 @@ class Trainer(object): ...@@ -151,12 +153,19 @@ class Trainer(object):
def _setup(self): def _setup(self):
""" setup Trainer-specific stuff for training""" """ setup Trainer-specific stuff for training"""
@property
def global_step(self):
return self._starting_step + \
self.config.steps_per_epoch * (self.epoch_num - 1) + \
self.local_step + 1
def main_loop(self): def main_loop(self):
""" """
Run the main training loop. Run the main training loop.
""" """
callbacks = self.config.callbacks callbacks = self.config.callbacks
with self.sess.as_default(): with self.sess.as_default():
self._starting_step = get_global_step_value()
try: try:
callbacks.before_train() callbacks.before_train()
for self.epoch_num in range( for self.epoch_num in range(
...@@ -173,7 +182,7 @@ class Trainer(object): ...@@ -173,7 +182,7 @@ class Trainer(object):
else: else:
callbacks.trigger_step(*fetch_data) callbacks.trigger_step(*fetch_data)
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format( logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, get_global_step_value(), time.time() - start_time)) self.epoch_num, self.global_step, time.time() - start_time))
# trigger epoch outside the timing region. # trigger epoch outside the timing region.
self.trigger_epoch() self.trigger_epoch()
......
...@@ -133,7 +133,7 @@ def get_tqdm(**kwargs): ...@@ -133,7 +133,7 @@ def get_tqdm(**kwargs):
def building_rtfd(): def building_rtfd():
""" """
Returns: Returns:
bool: if tensorpack is imported to generate docs now. bool: if tensorpack is being imported to generate docs now.
""" """
return os.environ.get('READTHEDOCS') == 'True' \ return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING') or os.environ.get('TENSORPACK_DOC_BUILDING')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment