Commit df74a4a9 authored by Yuxin Wu's avatar Yuxin Wu

stat change in trainer

parent 6745a4e4
...@@ -48,6 +48,7 @@ class StatHolder(object): ...@@ -48,6 +48,7 @@ class StatHolder(object):
self.print_tag = None if print_tag is None else set(print_tag) self.print_tag = None if print_tag is None else set(print_tag)
def add_blacklist_tag(self, blacklist_tag): def add_blacklist_tag(self, blacklist_tag):
""" Disable printing for some tags """
self.blacklist_tag |= set(blacklist_tag) self.blacklist_tag |= set(blacklist_tag)
def get_stat_now(self, key): def get_stat_now(self, key):
...@@ -67,7 +68,7 @@ class StatHolder(object): ...@@ -67,7 +68,7 @@ class StatHolder(object):
def finalize(self): def finalize(self):
""" """
Called after finishing adding stats. Will print and write stats to disk. Called after finishing adding stats for this epoch. Will print and write stats to disk.
""" """
self._print_stat() self._print_stat()
self.stat_history.append(self.stat_now) self.stat_history.append(self.stat_now)
...@@ -102,9 +103,9 @@ class StatPrinter(Callback): ...@@ -102,9 +103,9 @@ class StatPrinter(Callback):
def _before_train(self): def _before_train(self):
self.trainer.stat_holder.set_print_tag(self.print_tag) self.trainer.stat_holder.set_print_tag(self.print_tag)
self.trainer.stat_holder.add_blacklist_tag(['global_step', 'epoch_num'])
def _trigger_epoch(self): def _trigger_epoch(self):
self.trainer.stat_holder.add_stat('global_step', self.global_step)
self.trainer.stat_holder.finalize() self.trainer.stat_holder.finalize()
class SendStat(Callback): class SendStat(Callback):
......
...@@ -65,7 +65,13 @@ class Trainer(object): ...@@ -65,7 +65,13 @@ class Trainer(object):
return [self.get_predict_func(input_names, output_names) for k in range(n)] return [self.get_predict_func(input_names, output_names) for k in range(n)]
def trigger_epoch(self): def trigger_epoch(self):
# by default, add this two stat
self.stat_holder.add_stat('global_step', self.global_step)
self.stat_holder.add_stat('epoch_num', self.epoch_num)
# trigger subclass
self._trigger_epoch() self._trigger_epoch()
# trigger callbacks
self.config.callbacks.trigger_epoch() self.config.callbacks.trigger_epoch()
self.summary_writer.flush() self.summary_writer.flush()
...@@ -82,8 +88,6 @@ class Trainer(object): ...@@ -82,8 +88,6 @@ class Trainer(object):
self.summary_op = tf.merge_all_summaries() self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder # create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR) self.stat_holder = StatHolder(logger.LOG_DIR)
# save global_step in stat.json, but don't print it
self.stat_holder.add_blacklist_tag(['global_step'])
def _process_summary(self, summary_str): def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str) summary = tf.Summary.FromString(summary_str)
...@@ -118,23 +122,23 @@ class Trainer(object): ...@@ -118,23 +122,23 @@ class Trainer(object):
logger.info("Start training with global_step={}".format(self.global_step)) logger.info("Start training with global_step={}".format(self.global_step))
callbacks.before_train() callbacks.before_train()
for epoch in range(self.config.starting_epoch, self.config.max_epoch+1): for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation( with timed_operation(
'Epoch {}, global_step={}'.format( 'Epoch {}, global_step={}'.format(
epoch, self.global_step + self.config.step_per_epoch)): self.epoch_num, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)): **get_tqdm_kwargs(leave=True)):
if self.coord.should_stop(): if self.coord.should_stop():
return return
self.run_step() self.run_step() # implemented by subclass
#callbacks.trigger_step() # not useful? #callbacks.trigger_step() # not useful?
self.global_step += 1 self.global_step += 1
self.trigger_epoch() self.trigger_epoch()
except (KeyboardInterrupt, Exception): except (KeyboardInterrupt, Exception):
raise raise
finally: finally:
# Do I need to run queue.close?
callbacks.after_train() callbacks.after_train()
self.coord.request_stop() self.coord.request_stop()
self.summary_writer.close() self.summary_writer.close()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment