Commit efc8a5f0 authored by Yuxin Wu's avatar Yuxin Wu

allow scalar logging between steps

parent ad398637
......@@ -130,9 +130,9 @@ class TFSummaryWriter(TrainingMonitor):
self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
def put_summary(self, summary):
self._writer.add_summary(summary, self.trainer.global_step)
self._writer.add_summary(summary, self.global_step)
def _trigger(self):
def _trigger(self): # flush every epoch
self._writer.flush()
def _after_train(self):
......@@ -165,19 +165,25 @@ class JSONWriter(TrainingMonitor):
self._stat_now = {}
self._last_gs = -1
self._total = self.trainer.config.steps_per_epoch
def put_scalar(self, name, val):
gs = self.trainer.global_step
if gs != self._last_gs:
def _trigger_step(self):
# will do this in trigger_epoch
if self.local_step != self._total - 1:
self._push()
self._last_gs = gs
self._stat_now['epoch_num'] = self.trainer.epoch_num
self._stat_now['global_step'] = gs
def _trigger_epoch(self):
self._push()
def put_scalar(self, name, val):
self._stat_now[name] = float(val) # TODO will fail for non-numeric
def _push(self):
""" Note that this method is idempotent"""
if len(self._stat_now):
self._stat_now['epoch_num'] = self.epoch_num
self._stat_now['global_step'] = self.global_step
self._stats.append(self._stat_now)
self._stat_now = {}
self._write_stat()
......@@ -189,23 +195,42 @@ class JSONWriter(TrainingMonitor):
json.dump(self._stats, f)
shutil.move(tmp_filename, self._fname)
except IOError: # disk error sometimes..
logger.exception("Exception in StatHolder.finalize()!")
def _trigger(self):
self._push()
logger.exception("Exception in JSONWriter._write_stat()!")
# TODO print interval
class ScalarPrinter(TrainingMonitor):
"""
Print all scalar data in terminal.
Print scalar data into terminal.
"""
def __init__(self):
def __init__(self, enable_step=False, enable_epoch=True):
"""
Args:
enable_step, enable_epoch (bool): whether to print the
monitor data (if any) between steps or between epochs.
"""
self._whitelist = None
self._blacklist = set([])
self._enable_step = enable_step
self._enable_epoch = enable_epoch
def _setup_graph(self):
self._dic = {}
self._total = self.trainer.config.steps_per_epoch
def _trigger_step(self):
if self._enable_step:
if self.local_step != self._total - 1:
# not the last step
self._print_stat()
else:
if not self._enable_epoch:
self._print_stat()
# otherwise, will print them together
def _trigger_epoch(self):
if self._enable_epoch:
self._print_stat()
def put_scalar(self, name, val):
self._dic[name] = float(val)
......@@ -215,9 +240,6 @@ class ScalarPrinter(TrainingMonitor):
if self._whitelist is None or k in self._whitelist:
if k not in self._blacklist:
logger.info('{}: {:.5g}'.format(k, v))
def _trigger(self):
self._print_stat()
self._dic = {}
......
......@@ -15,7 +15,7 @@ class StatPrinter(Callback):
def __init__(self, print_tag=None):
log_deprecated("StatPrinter",
"No need to add StatPrinter to callbacks anymore!",
"2017-03-26")
"2017-05-26")
# TODO make it into monitor?
......
......@@ -85,7 +85,8 @@ class ProgressBar(Callback):
def __init__(self, names=[]):
"""
Args:
names(list): list of string, the names of the tensors to display.
names(list): list of string, the names of the tensors to monitor
on the progress bar.
"""
super(ProgressBar, self).__init__()
self._names = [get_op_tensor_name(n)[1] for n in names]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment