Commit 000b2ea3 authored by Yuxin Wu

monitor becomes no-op when logger is not set

parent 0f4eda16
@@ -34,6 +34,7 @@ class ModelSaver(Triggerable):
        self.var_collections = var_collections
        if checkpoint_dir is None:
            checkpoint_dir = logger.LOG_DIR
        assert os.path.isdir(checkpoint_dir), checkpoint_dir
        self.checkpoint_dir = checkpoint_dir

    def _setup_graph(self):
......
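For reference, a minimal standalone sketch of the pattern this hunk converges on: the checkpoint directory falls back to the globally configured log directory and is validated at construction time. The class name, the DEFAULT_LOG_DIR lookup and the trimmed-down constructor below are hypothetical stand-ins, not tensorpack's actual API.

import os


class ModelSaverSketch(object):
    """Hypothetical, simplified ModelSaver: only the directory handling."""

    DEFAULT_LOG_DIR = os.environ.get("TRAIN_LOG_DIR")  # stand-in for logger.LOG_DIR

    def __init__(self, checkpoint_dir=None):
        if checkpoint_dir is None:
            # fall back to the globally configured log directory
            checkpoint_dir = self.DEFAULT_LOG_DIR
        # fail early rather than at the first checkpoint write
        assert checkpoint_dir and os.path.isdir(checkpoint_dir), checkpoint_dir
        self.checkpoint_dir = checkpoint_dir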
@@ -43,6 +43,7 @@ class Trainer(object):
        epoch_num (int): the number of epochs that have finished.
        local_step (int): the number of steps that have finished in the current epoch.
        global_step (int): the number of steps that have finished.
    """

    def __init__(self, config):
@@ -59,10 +60,12 @@ class Trainer(object):
        self._callbacks = []
        self.register_callback(MaintainStepCounter())
        for cb in self.config.callbacks:
        for cb in config.callbacks:
            self.register_callback(cb)

        self.monitors = config.monitors
        self.monitors = []
        for m in config.monitors:
            self.register_monitor(m)

    def register_callback(self, cb):
        """
@@ -100,9 +103,6 @@ class Trainer(object):
        """
        Setup the trainer and be ready for the main loop.
        """
        if not hasattr(logger, 'LOG_DIR'):
            raise RuntimeError("logger directory wasn't set!")
        self._setup()   # subclass will setup the graph
        describe_model()
......
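Taken together, the two trainer hunks change the registration flow: monitors are now registered one at a time instead of assigning config.monitors wholesale, and the hard RuntimeError on a missing logger directory is gone, since each monitor now decides on its own what to do without one. A hedged sketch of that flow (the register_* bodies below are assumptions, not the real implementations):

class TrainerSketch(object):
    """Hypothetical, simplified Trainer.__init__ showing only the registration order."""

    def __init__(self, callbacks, monitors):
        self._callbacks = []
        self.monitors = []
        for cb in callbacks:
            self.register_callback(cb)
        for m in monitors:
            self.register_monitor(m)

    def register_callback(self, cb):
        # collect callbacks; the real method presumably does more
        self._callbacks.append(cb)

    def register_monitor(self, m):
        # registering individually gives each monitor a chance to set itself
        # up, or to already be a NoOpMonitor (see the writer classes below)
        self.monitors.append(m)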
@@ -18,8 +18,16 @@ __all__ = ['TrainingMonitor', 'Monitors',


class TrainingMonitor(object):
    """
    Monitor training progress by processing different types of
    summary/statistics from the trainer.
    """
    def setup(self, trainer):
        self._trainer = trainer
        self._setup()

    def _setup(self):
        pass

    def put_summary(self, summary):
        pass

@@ -34,6 +42,10 @@ class TrainingMonitor(object):
        pass


class NoOpMonitor(TrainingMonitor):
    pass


class Monitors(TrainingMonitor):
    def __init__(self, monitors):
        # TODO filter by names
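Monitors is the aggregator the trainer talks to; the diff only shows its constructor. A hedged sketch of the dispatch idea, where the fan-out is an assumption about its role rather than something shown in this commit:

class MonitorsSketch(TrainingMonitor):
    """Hypothetical aggregator: forward each summary to every child monitor."""

    def __init__(self, monitors):
        self._monitors = monitors

    def put_summary(self, summary):
        for m in self._monitors:
            # NoOpMonitor children simply ignore the call
            m.put_summary(summary)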
@@ -92,6 +104,13 @@ class Monitors(TrainingMonitor):


class TFSummaryWriter(TrainingMonitor):
    def __new__(cls):
        if logger.LOG_DIR:
            return super(TFSummaryWriter, cls).__new__(cls)
        else:
            logger.warn("logger directory was not set. Ignore TFSummaryWriter.")
            return NoOpMonitor()

    def setup(self, trainer):
        super(TFSummaryWriter, self).setup(trainer)
        self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
@@ -107,6 +126,13 @@ class TFSummaryWriter(TrainingMonitor):


class JSONWriter(TrainingMonitor):
    def __new__(cls):
        if logger.LOG_DIR:
            return super(JSONWriter, cls).__new__(cls)
        else:
            logger.warn("logger directory was not set. Ignore JSONWriter.")
            return NoOpMonitor()

    def setup(self, trainer):
        super(JSONWriter, self).setup(trainer)
        self._dir = logger.LOG_DIR
......
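The core of the commit is the __new__ trick shared by TFSummaryWriter and JSONWriter: when no log directory is configured, constructing the writer yields a NoOpMonitor instead, so callers can instantiate it unconditionally. A self-contained sketch of that technique; the FileMonitor class and the environment-variable lookup are hypothetical stand-ins for logger.LOG_DIR:

import os


class TrainingMonitor(object):
    """Base interface: every hook is a no-op by default."""

    def put_summary(self, summary):
        pass


class NoOpMonitor(TrainingMonitor):
    pass


class FileMonitor(TrainingMonitor):
    """Hypothetical monitor that needs a log directory to be useful."""

    LOG_DIR = os.environ.get("TRAIN_LOG_DIR")  # stand-in for logger.LOG_DIR

    def __new__(cls):
        if cls.LOG_DIR:
            return super(FileMonitor, cls).__new__(cls)
        # Returning an instance of a different class from __new__ means
        # FileMonitor.__init__ is never run; the caller just gets a no-op.
        print("log directory was not set; using NoOpMonitor instead")
        return NoOpMonitor()

    def put_summary(self, summary):
        with open(os.path.join(self.LOG_DIR, "summaries.log"), "a") as f:
            f.write("{}\n".format(summary))


m = FileMonitor()          # a NoOpMonitor when TRAIN_LOG_DIR is unset
m.put_summary("loss=0.5")  # silently dropped in that case

This is what allows the trainer-level check (the RuntimeError removed in the Trainer.setup hunk above) to go away: a missing logger directory now degrades the writers to no-ops instead of aborting setup.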