Commit 67786cbb authored by Yuxin Wu

latest model link

parent 771a8b0f
...@@ -61,7 +61,17 @@ class ModelSaver(Callback): ...@@ -61,7 +61,17 @@ class ModelSaver(Callback):
self.path, self.path,
global_step=self.global_step, global_step=self.global_step,
write_meta_graph=False) write_meta_graph=False)
except Exception: # disk error sometimes..
# create a symbolic link for the latest model
latest = self.saver.last_checkpoints[-1]
basename = os.path.basename(latest)
linkname = os.path.join(os.path.dirname(latest), 'latest')
try:
os.unlink(linkname)
except FileNotFoundError:
pass
os.symlink(basename, linkname)
except Exception: # disk error sometimes.. just ignore
logger.exception("Exception in ModelSaver.trigger_epoch!") logger.exception("Exception in ModelSaver.trigger_epoch!")
class MinSaver(Callback): class MinSaver(Callback):
......
...@@ -22,6 +22,7 @@ class StatHolder(object): ...@@ -22,6 +22,7 @@ class StatHolder(object):
:param log_dir: directory to save the stats. :param log_dir: directory to save the stats.
""" """
self.set_print_tag([]) self.set_print_tag([])
self.blacklist_tag = set()
self.stat_now = {} self.stat_now = {}
self.log_dir = log_dir self.log_dir = log_dir
...@@ -47,6 +48,9 @@ class StatHolder(object): ...@@ -47,6 +48,9 @@ class StatHolder(object):
""" """
self.print_tag = None if print_tag is None else set(print_tag) self.print_tag = None if print_tag is None else set(print_tag)
def add_blacklist_tag(self, blacklist_tag):
self.blacklist_tag |= set(blacklist_tag)
def get_stat_now(self, key): def get_stat_now(self, key):
""" """
Return the value of a stat in the current epoch. Return the value of a stat in the current epoch.
...@@ -65,6 +69,7 @@ class StatHolder(object): ...@@ -65,6 +69,7 @@ class StatHolder(object):
def _print_stat(self): def _print_stat(self):
for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)): for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
if self.print_tag is None or k in self.print_tag: if self.print_tag is None or k in self.print_tag:
if k not in self.blacklist_tag:
logger.info('{}: {:.5g}'.format(k, v)) logger.info('{}: {:.5g}'.format(k, v))
def _write_stat(self): def _write_stat(self):
......
...@@ -68,6 +68,7 @@ class Trainer(object): ...@@ -68,6 +68,7 @@ class Trainer(object):
self._trigger_epoch() self._trigger_epoch()
self.config.callbacks.trigger_epoch() self.config.callbacks.trigger_epoch()
self.summary_writer.flush() self.summary_writer.flush()
self.stat_holder.add_stat('global_step', self.global_step)
self.stat_holder.finalize() self.stat_holder.finalize()
@abstractmethod @abstractmethod
...@@ -83,6 +84,8 @@ class Trainer(object): ...@@ -83,6 +84,8 @@ class Trainer(object):
self.summary_op = tf.merge_all_summaries() self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder # create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR) self.stat_holder = StatHolder(logger.LOG_DIR)
# save global_step in stat.json, but don't print it
self.stat_holder.add_blacklist_tag(['global_step'])
def _process_summary(self, summary_str): def _process_summary(self, summary_str):
summary = tf.Summary.FromString(summary_str) summary = tf.Summary.FromString(summary_str)
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment