Commit 63bf3d12 authored by Yuxin Wu's avatar Yuxin Wu

stat_holder now maintain the steps itself, to work better with the step...

stat_holder now maintain the steps itself, to work better with the step callbacks which produce stats.
parent 7f6c83d8
......@@ -120,9 +120,7 @@ class Callback(object):
@property
def global_step(self):
return self._starting_step + \
self._steps_per_epoch * (self.epoch_num - 1) + \
self.local_step
return self.trainer.global_step
def __str__(self):
return type(self).__name__
......
......@@ -69,11 +69,6 @@ class Callbacks(Callback):
logger.warn("StatPrinter should appear as the last element of callbacks! "
"This is now fixed automatically, but may not work in the future.")
break
else:
raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
nr_printer = sum([int(isinstance(cb, StatPrinter)) for cb in cbs])
if nr_printer != 1:
raise ValueError("Callbacks must contain one StatPrinter!")
self.cbs = cbs
self._extra_fetches_cache = None
......
......@@ -36,10 +36,18 @@ class StatHolder(object):
else:
self.stat_history = []
def add_stat(self, k, v):
# global step of the current list of stat
self._current_gs = -1
def add_stat(self, k, v, global_step, epoch_num):
"""
Add a stat.
"""
if global_step != self._current_gs:
self._push()
self._current_gs = global_step
self.stat_now['epoch_num'] = epoch_num
self.stat_now['global_step'] = global_step
self.stat_now[k] = float(v)
def set_print_tag(self, print_tag):
......@@ -85,10 +93,15 @@ class StatHolder(object):
def finalize(self):
"""
Called after finishing adding stats for this epoch.
Will print and write stats to disk.
Print and write stats to disk.
This method is idempotent.
"""
self._print_stat()
self._push()
def _push(self):
""" Note that this method is idempotent"""
if len(self.stat_now):
self.stat_history.append(self.stat_now)
self.stat_now = {}
self._write_stat()
......@@ -128,15 +141,8 @@ class StatPrinter(Triggerable):
self._stat_holder.set_print_tag(self.print_tag)
self._stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
# just try to add this stat earlier so SendStat can use
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
def _trigger(self):
# by default, add this two stat
self._stat_holder.add_stat('global_step', self.global_step)
self._stat_holder.finalize()
# this is for the next group of stat
self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
class SendStat(Triggerable):
......
......@@ -109,7 +109,9 @@ class Trainer(object):
suffix = '-summary' # issue#6150
if val.tag.endswith(suffix):
val.tag = val.tag[:-len(suffix)]
self.stat_holder.add_stat(val.tag, val.simple_value)
self.stat_holder.add_stat(
val.tag, val.simple_value,
self.global_step, self.epoch_num)
self.summary_writer.add_summary(summary, get_global_step_value())
def add_scalar_summary(self, name, val):
......@@ -151,12 +153,19 @@ class Trainer(object):
def _setup(self):
""" setup Trainer-specific stuff for training"""
@property
def global_step(self):
return self._starting_step + \
self.config.steps_per_epoch * (self.epoch_num - 1) + \
self.local_step + 1
def main_loop(self):
"""
Run the main training loop.
"""
callbacks = self.config.callbacks
with self.sess.as_default():
self._starting_step = get_global_step_value()
try:
callbacks.before_train()
for self.epoch_num in range(
......@@ -173,7 +182,7 @@ class Trainer(object):
else:
callbacks.trigger_step(*fetch_data)
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.epoch_num, get_global_step_value(), time.time() - start_time))
self.epoch_num, self.global_step, time.time() - start_time))
# trigger epoch outside the timing region.
self.trigger_epoch()
......
......@@ -133,7 +133,7 @@ def get_tqdm(**kwargs):
def building_rtfd():
"""
Returns:
bool: if tensorpack is imported to generate docs now.
bool: if tensorpack is being imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment