Commit df74a4a9 authored by Yuxin Wu's avatar Yuxin Wu

stat change in trainer

parent 6745a4e4
......@@ -48,6 +48,7 @@ class StatHolder(object):
self.print_tag = None if print_tag is None else set(print_tag)
def add_blacklist_tag(self, blacklist_tag):
""" Disable printing for some tags """
self.blacklist_tag |= set(blacklist_tag)
def get_stat_now(self, key):
......@@ -67,7 +68,7 @@ class StatHolder(object):
def finalize(self):
"""
Called after finishing adding stats. Will print and write stats to disk.
Called after finishing adding stats for this epoch. Will print and write stats to disk.
"""
self._print_stat()
self.stat_history.append(self.stat_now)
......@@ -102,9 +103,9 @@ class StatPrinter(Callback):
def _before_train(self):
self.trainer.stat_holder.set_print_tag(self.print_tag)
self.trainer.stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
def _trigger_epoch(self):
self.trainer.stat_holder.add_stat('global_step', self.global_step)
self.trainer.stat_holder.finalize()
class SendStat(Callback):
......
......@@ -65,7 +65,13 @@ class Trainer(object):
return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self):
# by default, add this two stat
self.stat_holder.add_stat('global_step', self.global_step)
self.stat_holder.add_stat('epoch_num', self.epoch_num)
# trigger subclass
self._trigger_epoch()
# trigger callbacks
self.config.callbacks.trigger_epoch()
self.summary_writer.flush()
......@@ -82,8 +88,6 @@ class Trainer(object):
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR)
# save global_step in stat.json, but don't print it
self.stat_holder.add_blacklist_tag(['global_step'])
def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str)
......@@ -118,23 +122,23 @@ class Trainer(object):
logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train()
for epoch in range(self.config.starting_epoch, self.config.max_epoch+1):
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation(
'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)):
self.epoch_num, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)):
if self.coord.should_stop():
return
self.run_step()
self.run_step() # implemented by subclass
#callbacks.trigger_step() # not useful?
self.global_step += 1
self.trigger_epoch()
except (KeyboardInterrupt, Exception):
raise
finally:
# Do I need to run queue.close?
callbacks.after_train()
self.coord.request_stop()
self.summary_writer.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment