Commit fff3f2d3 authored by Yuxin Wu

Use `logger.get_logger_dir()` to access logger directory.

parent 1be49dc9
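For context, a minimal before/after sketch of the API change this commit makes. The `tensorpack.utils.logger` import path and the directory name are assumptions for illustration, not part of the diff:

```python
import os
from tensorpack.utils import logger

logger.set_logger_dir('train_log/my_experiment')  # hypothetical directory

# Before this commit, callers read the module-level variable directly:
#   output_file = os.path.join(logger.LOG_DIR, 'outputs.json')

# After this commit, callers go through the accessor instead:
output_file = os.path.join(logger.get_logger_dir(), 'outputs.json')
```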
@@ -231,7 +231,7 @@ class EvalCallback(Callback):
def _eval(self):
all_results = eval_on_dataflow(self.df, lambda img: detect_one_image(img, self.pred))
output_file = os.path.join(
-logger.LOG_DIR, 'outputs{}.json'.format(self.global_step))
+logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step))
with open(output_file, 'w') as f:
json.dump(all_results, f)
print_evaluation_scores(output_file)
......
@@ -134,7 +134,8 @@ class ProcessTensors(Callback):
class DumpTensors(ProcessTensors):
"""
Dump some tensors to a file.
-Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
+Every step this callback fetches tensors and writes them to an npz file
+under ``logger.get_logger_dir()``.
The dump can be loaded by ``dict(np.load(filename).items())``.
"""
def __init__(self, names):
@@ -144,7 +145,7 @@ class DumpTensors(ProcessTensors):
"""
assert isinstance(names, (list, tuple)), names
self._names = names
-dir = logger.LOG_DIR
+dir = logger.get_logger_dir()
def fn(*args):
dic = {}
......
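Since the docstring above says the dump can be loaded with ``dict(np.load(filename).items())``, here is a minimal sketch of reading one dump back; the filename is hypothetical, as the exact naming scheme is not shown in this diff:

```python
import numpy as np

# Load one npz file that DumpTensors wrote under the logger directory
# (hypothetical filename; the real name depends on the callback's scheme).
dump = dict(np.load('train_log/my_experiment/dump-step1000.npz').items())
for name, arr in dump.items():
    print(name, arr.shape, arr.dtype)
```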
@@ -193,14 +193,14 @@ class TFEventWriter(TrainingMonitor):
Write summaries to TensorFlow event file.
"""
def __new__(cls):
-if logger.LOG_DIR:
+if logger.get_logger_dir():
return super(TFEventWriter, cls).__new__(cls)
else:
logger.warn("logger directory was not set. Ignore TFEventWriter.")
return NoOpMonitor()
def _setup_graph(self):
-self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
+self._writer = tf.summary.FileWriter(logger.get_logger_dir(), graph=tf.get_default_graph())
def process_summary(self, summary):
self._writer.add_summary(summary, self.global_step)
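The ``__new__`` override above means constructing the monitor is safe even without a logger directory. A small sketch, assuming ``TFEventWriter`` is importable from ``tensorpack.callbacks``:

```python
from tensorpack.utils import logger
from tensorpack.callbacks import TFEventWriter

# With no logger directory set, __new__ logs the warning shown above
# and hands back a NoOpMonitor instead of raising.
monitor = TFEventWriter()
print(type(monitor).__name__)  # NoOpMonitor

logger.set_logger_dir('train_log/demo')  # hypothetical directory
monitor = TFEventWriter()                # now a real TFEventWriter
```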
@@ -222,7 +222,7 @@ def TFSummaryWriter(*args, **kwargs):
class JSONWriter(TrainingMonitor):
"""
-Write all scalar data to a json file under ``logger.LOG_DIR``, grouped by their global step.
+Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
This monitor also attempts to recover the epoch number during setup,
if an existing json file is found at the same place.
"""
@@ -233,14 +233,14 @@ class JSONWriter(TrainingMonitor):
"""
def __new__(cls):
-if logger.LOG_DIR:
+if logger.get_logger_dir():
return super(JSONWriter, cls).__new__(cls)
else:
logger.warn("logger directory was not set. Ignore JSONWriter.")
return NoOpMonitor()
def _before_train(self):
-self._dir = logger.LOG_DIR
+self._dir = logger.get_logger_dir()
self._fname = os.path.join(self._dir, self.FILENAME)
if os.path.isfile(self._fname):
......
@@ -184,7 +184,7 @@ class HumanHyperParamSetter(HyperParamSetter):
If the pair is not found, the param will not be changed.
"""
super(HumanHyperParamSetter, self).__init__(param)
-self.file_name = os.path.join(logger.LOG_DIR, file_name)
+self.file_name = os.path.join(logger.get_logger_dir(), file_name)
logger.info("Use {} to set hyperparam: '{}'.".format(
self.file_name, self.param.readable_name))
......
@@ -105,7 +105,7 @@ class GPUUtilizationTracker(Callback):
class GraphProfiler(Callback):
"""
Enable profiling by installing session hooks,
-and write metadata or tracing files to ``logger.LOG_DIR``.
+and write metadata or tracing files to ``logger.get_logger_dir()``.
The tracing files can be loaded from ``chrome://tracing``.
The metadata files can be processed by
@@ -125,7 +125,7 @@ class GraphProfiler(Callback):
dump_event(bool): Dump to an event processed by FileWriter and
will be shown in TensorBoard.
"""
-self._dir = logger.LOG_DIR
+self._dir = logger.get_logger_dir()
self._dump_meta = bool(dump_metadata)
self._dump_tracing = bool(dump_tracing)
self._dump_event = bool(dump_event)
......
@@ -27,7 +27,7 @@ class ModelSaver(Callback):
"""
Args:
max_to_keep, keep_checkpoint_every_n_hours(int): the same as in ``tf.train.Saver``.
-checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
+checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save.
"""
self._max_to_keep = max_to_keep
@@ -43,7 +43,7 @@ class ModelSaver(Callback):
var_collections = [var_collections]
self.var_collections = var_collections
if checkpoint_dir is None:
-checkpoint_dir = logger.LOG_DIR
+checkpoint_dir = logger.get_logger_dir()
assert checkpoint_dir is not None
if not tf.gfile.IsDirectory(checkpoint_dir):
tf.gfile.MakeDirs(checkpoint_dir)
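A sketch of the default-directory behavior above, i.e. constructing ``ModelSaver`` without a ``checkpoint_dir`` once the logger directory is set; the import paths are assumed from the package layout:

```python
from tensorpack.utils import logger
from tensorpack.callbacks import ModelSaver

logger.set_logger_dir('train_log/my_experiment')  # hypothetical directory

# With checkpoint_dir=None, the constructor falls back to
# logger.get_logger_dir() and creates the directory if needed.
saver = ModelSaver(max_to_keep=10)
```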
@@ -115,7 +115,7 @@ class MinSaver(Callback):
Note:
It assumes that :class:`ModelSaver` is used with
-``checkpoint_dir=logger.LOG_DIR`` (the default). And it will save
+``checkpoint_dir=logger.get_logger_dir()`` (the default). And it will save
the model to that directory as well.
"""
self.monitor_stat = monitor_stat
@@ -143,13 +143,13 @@ class MinSaver(Callback):
self._save()
def _save(self):
-ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+ckpt = tf.train.get_checkpoint_state(logger.get_logger_dir())
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?")
path = ckpt.model_checkpoint_path
-newname = os.path.join(logger.LOG_DIR,
+newname = os.path.join(logger.get_logger_dir(),
self.filename or
('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
files_to_copy = tf.gfile.Glob(path + '*')
......
@@ -71,7 +71,7 @@ class InjectShell(Callback):
class DumpParamAsImage(Callback):
"""
-Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
+Dump a tensor to image(s) to ``logger.get_logger_dir()`` after every epoch.
Note that it requires the tensor to be directly evaluable, i.e. either inputs
are not its dependency (e.g. the weights of the model), or the inputs are
@@ -93,7 +93,7 @@ class DumpParamAsImage(Callback):
self.prefix = op_name
else:
self.prefix = prefix
-self.log_dir = logger.LOG_DIR
+self.log_dir = logger.get_logger_dir()
self.scale = scale
def _before_train(self):
......
@@ -270,14 +270,14 @@ def get_model_loader(filename):
def TryResumeTraining():
"""
-Try loading latest checkpoint from ``logger.LOG_DIR``, only if there is one.
+Try loading the latest checkpoint from ``logger.get_logger_dir()``, only if there is one.
Returns:
SessInit: either a :class:`JustCurrentSession`, or a :class:`SaverRestore`.
"""
-if not logger.LOG_DIR:
+if not logger.get_logger_dir():
return JustCurrentSession()
-path = os.path.join(logger.LOG_DIR, 'checkpoint')
+path = os.path.join(logger.get_logger_dir(), 'checkpoint')
if not tf.gfile.Exists(path):
return JustCurrentSession()
return SaverRestore(path)
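A sketch of how the helper above might be used to pick a session initializer. The import locations are inferred from the surrounding file (``get_model_loader`` suggests ``tensorpack.tfutils.sessinit``) and are an assumption:

```python
from tensorpack.utils import logger
from tensorpack.tfutils.sessinit import TryResumeTraining, JustCurrentSession

logger.set_logger_dir('train_log/my_experiment')  # hypothetical directory

# Returns SaverRestore if a 'checkpoint' file exists under the logger
# directory, otherwise JustCurrentSession (also when no directory is set).
init = TryResumeTraining()
if isinstance(init, JustCurrentSession):
    print('No checkpoint found; starting fresh in', logger.get_logger_dir())
```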
@@ -11,7 +11,7 @@ from datetime import datetime
from six.moves import input
import sys
-__all__ = ['set_logger_dir', 'auto_set_dir']
+__all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir']
class _MyFormatter(logging.Formatter):
@@ -80,7 +80,8 @@ def set_logger_dir(dirname, action=None):
Args:
dirname(str): log directory
-action(str): an action of ("k","b","d","n","q") to be performed. Will ask user by default.
+action(str): an action of ("k","b","d","n","q") to be performed
+    when the directory exists. Will ask the user by default.
"""
global LOG_DIR, _FILE_HANDLER
if _FILE_HANDLER:
@@ -128,3 +129,12 @@ def auto_set_dir(action=None, name=None):
if name:
auto_dirname += ':%s' % name
set_logger_dir(auto_dirname, action=action)
+def get_logger_dir():
+    """
+    Returns:
+        The logger directory, or None if not set.
+        The directory is used for general logging, tensorboard events, checkpoints, etc.
+    """
+    return LOG_DIR
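Finally, a short sketch of the new accessor in use together with ``set_logger_dir`` from the same module. The directory name is hypothetical, and the meaning of action ``'d'`` (delete an existing directory) is an assumption based on the docstring's action list:

```python
from tensorpack.utils import logger

print(logger.get_logger_dir())   # None before any directory is set

# 'd' is one of the documented actions ("k","b","d","n","q"); assumed
# here to delete an existing directory so the call never prompts.
logger.set_logger_dir('train_log/demo', action='d')
print(logger.get_logger_dir())   # train_log/demo
```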