Commit efc8a5f0 authored by Yuxin Wu

allow scalar logging between steps

parent ad398637
@@ -130,9 +130,9 @@ class TFSummaryWriter(TrainingMonitor):
         self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())

     def put_summary(self, summary):
-        self._writer.add_summary(summary, self.trainer.global_step)
+        self._writer.add_summary(summary, self.global_step)

-    def _trigger(self):
+    def _trigger(self):  # flush every epoch
         self._writer.flush()

     def _after_train(self):
@@ -165,19 +165,25 @@ class JSONWriter(TrainingMonitor):
         self._stat_now = {}
         self._last_gs = -1
+        self._total = self.trainer.config.steps_per_epoch

-    def put_scalar(self, name, val):
-        gs = self.trainer.global_step
-        if gs != self._last_gs:
-            self._push()
-            self._last_gs = gs
-            self._stat_now['epoch_num'] = self.trainer.epoch_num
-            self._stat_now['global_step'] = gs
+    def _trigger_step(self):
+        # will do this in trigger_epoch
+        if self.local_step != self._total - 1:
+            self._push()
+
+    def _trigger_epoch(self):
+        self._push()
+
+    def put_scalar(self, name, val):
         self._stat_now[name] = float(val)    # TODO will fail for non-numeric

     def _push(self):
         """ Note that this method is idempotent"""
         if len(self._stat_now):
+            self._stat_now['epoch_num'] = self.epoch_num
+            self._stat_now['global_step'] = self.global_step
             self._stats.append(self._stat_now)
             self._stat_now = {}
             self._write_stat()
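For reference, after this change each record that JSONWriter._push() appends to self._stats carries the current epoch_num and global_step together with whatever scalars were put_scalar()'d since the last push, so the JSON file written by _write_stat() ends up as a list of records along these lines (the scalar names and values below are made up for illustration):

    [
        {"epoch_num": 1, "global_step": 120, "loss": 0.4231, "accuracy": 0.87},
        {"epoch_num": 2, "global_step": 240, "loss": 0.3377, "accuracy": 0.91}
    ]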
@@ -189,23 +195,42 @@ class JSONWriter(TrainingMonitor):
             json.dump(self._stats, f)
             shutil.move(tmp_filename, self._fname)
         except IOError:  # disk error sometimes..
-            logger.exception("Exception in StatHolder.finalize()!")
-
-    def _trigger(self):
-        self._push()
+            logger.exception("Exception in JSONWriter._write_stat()!")


+# TODO print interval
 class ScalarPrinter(TrainingMonitor):
     """
-    Print all scalar data in terminal.
+    Print scalar data into terminal.
     """
-    def __init__(self):
+    def __init__(self, enable_step=False, enable_epoch=True):
+        """
+        Args:
+            enable_step, enable_epoch (bool): whether to print the
+                monitor data (if any) between steps or between epochs.
+        """
         self._whitelist = None
         self._blacklist = set([])
+        self._enable_step = enable_step
+        self._enable_epoch = enable_epoch

     def _setup_graph(self):
         self._dic = {}
+        self._total = self.trainer.config.steps_per_epoch
+
+    def _trigger_step(self):
+        if self._enable_step:
+            if self.local_step != self._total - 1:
+                # not the last step
+                self._print_stat()
+            else:
+                if not self._enable_epoch:
+                    self._print_stat()
+                # otherwise, will print them together
+
+    def _trigger_epoch(self):
+        if self._enable_epoch:
+            self._print_stat()

     def put_scalar(self, name, val):
         self._dic[name] = float(val)
@@ -215,9 +240,6 @@ class ScalarPrinter(TrainingMonitor):
             if self._whitelist is None or k in self._whitelist:
                 if k not in self._blacklist:
                     logger.info('{}: {:.5g}'.format(k, v))
-
-    def _trigger(self):
-        self._print_stat()
         self._dic = {}
...
@@ -15,7 +15,7 @@ class StatPrinter(Callback):
     def __init__(self, print_tag=None):
         log_deprecated("StatPrinter",
                        "No need to add StatPrinter to callbacks anymore!",
-                       "2017-03-26")
+                       "2017-05-26")
         # TODO make it into monitor?
...
@@ -85,7 +85,8 @@ class ProgressBar(Callback):
     def __init__(self, names=[]):
         """
         Args:
-            names(list): list of string, the names of the tensors to display.
+            names(list): list of string, the names of the tensors to monitor
+                on the progress bar.
         """
         super(ProgressBar, self).__init__()
         self._names = [get_op_tensor_name(n)[1] for n in names]
...
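To illustrate the step/epoch interaction this commit introduces in ScalarPrinter, here is a minimal standalone sketch. The FakeScalarPrinter class below is hypothetical and not part of tensorpack; it only mirrors the trigger logic shown in the diff. With enable_step=True the monitor prints after every step except the last step of an epoch, which is deferred to the epoch trigger so the same data is not printed twice:

    # Hypothetical standalone mock of the new ScalarPrinter trigger logic.
    class FakeScalarPrinter:
        def __init__(self, steps_per_epoch, enable_step=False, enable_epoch=True):
            self._total = steps_per_epoch
            self._enable_step = enable_step
            self._enable_epoch = enable_epoch
            self._dic = {}

        def put_scalar(self, name, val):
            self._dic[name] = float(val)

        def trigger_step(self, local_step):
            # mirrors the new ScalarPrinter._trigger_step
            if self._enable_step:
                if local_step != self._total - 1:
                    self._print_stat()          # print between steps
                elif not self._enable_epoch:
                    self._print_stat()          # last step, and no epoch trigger coming
                # otherwise the epoch trigger prints them together

        def trigger_epoch(self):
            # mirrors the new ScalarPrinter._trigger_epoch
            if self._enable_epoch:
                self._print_stat()

        def _print_stat(self):
            for k, v in sorted(self._dic.items()):
                print('{}: {:.5g}'.format(k, v))
            self._dic = {}


    printer = FakeScalarPrinter(steps_per_epoch=3, enable_step=True)
    for step in range(3):
        printer.put_scalar('loss', 1.0 / (step + 1))
        printer.trigger_step(step)
    printer.trigger_epoch()  # the last step's scalars are printed here, exactly once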