Commit fff3f2d3 authored by Yuxin Wu

Use `logger.get_logger_dir()` to access logger directory.

parent 1be49dc9
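In short: every read of the module-level ``logger.LOG_DIR`` attribute becomes a call to the new ``logger.get_logger_dir()`` accessor. A minimal before/after sketch (the directory name and file name below are illustrative, not from the diff):

    import os
    from tensorpack.utils import logger

    logger.set_logger_dir('train_log/myexp')   # hypothetical directory

    # before this commit:
    #   output_file = os.path.join(logger.LOG_DIR, 'out.json')
    # after this commit:
    output_file = os.path.join(logger.get_logger_dir(), 'out.json')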
@@ -231,7 +231,7 @@ class EvalCallback(Callback):
     def _eval(self):
         all_results = eval_on_dataflow(self.df, lambda img: detect_one_image(img, self.pred))
         output_file = os.path.join(
-            logger.LOG_DIR, 'outputs{}.json'.format(self.global_step))
+            logger.get_logger_dir(), 'outputs{}.json'.format(self.global_step))
         with open(output_file, 'w') as f:
             json.dump(all_results, f)
         print_evaluation_scores(output_file)
...
@@ -134,7 +134,8 @@ class ProcessTensors(Callback):
 class DumpTensors(ProcessTensors):
     """
     Dump some tensors to a file.
-    Every step this callback fetches tensors and write them to a npz file under ``logger.LOG_DIR``.
+    Every step this callback fetches tensors and write them to a npz file
+    under ``logger.get_logger_dir``.
     The dump can be loaded by ``dict(np.load(filename).items())``.
     """
     def __init__(self, names):
@@ -144,7 +145,7 @@ class DumpTensors(ProcessTensors):
         """
         assert isinstance(names, (list, tuple)), names
         self._names = names
-        dir = logger.LOG_DIR
+        dir = logger.get_logger_dir()

         def fn(*args):
             dic = {}
...
@@ -193,14 +193,14 @@ class TFEventWriter(TrainingMonitor):
     Write summaries to TensorFlow event file.
     """
     def __new__(cls):
-        if logger.LOG_DIR:
+        if logger.get_logger_dir():
             return super(TFEventWriter, cls).__new__(cls)
         else:
             logger.warn("logger directory was not set. Ignore TFEventWriter.")
             return NoOpMonitor()

     def _setup_graph(self):
-        self._writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
+        self._writer = tf.summary.FileWriter(logger.get_logger_dir(), graph=tf.get_default_graph())

     def process_summary(self, summary):
         self._writer.add_summary(summary, self.global_step)
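The ``__new__`` overrides in this file implement a small factory trick: when no logger directory is configured, constructing the writer hands back a ``NoOpMonitor`` instead of an instance of the class itself. A stripped-down sketch of the pattern (class names simplified for illustration):

    from tensorpack.utils import logger

    class NoOp(object):
        """Stand-in that silently ignores whatever the trainer sends."""
        def process_summary(self, summary):
            pass

    class Writer(object):
        def __new__(cls):
            if logger.get_logger_dir():
                # a log directory is set: construct a real Writer as usual
                return super(Writer, cls).__new__(cls)
            # no directory: hand back the stand-in instead
            return NoOp()

Because the returned ``NoOp`` is not an instance of ``Writer``, Python also skips ``Writer.__init__``, so callers transparently receive a monitor that does nothing.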
@@ -222,7 +222,7 @@ def TFSummaryWriter(*args, **kwargs):
 class JSONWriter(TrainingMonitor):
     """
-    Write all scalar data to a json file under ``logger.LOG_DIR``, grouped by their global step.
+    Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
     This monitor also attemps to recover the epoch number during setup,
     if an existing json file is found at the same place.
     """
@@ -233,14 +233,14 @@ class JSONWriter(TrainingMonitor):
     """
     def __new__(cls):
-        if logger.LOG_DIR:
+        if logger.get_logger_dir():
             return super(JSONWriter, cls).__new__(cls)
         else:
             logger.warn("logger directory was not set. Ignore JSONWriter.")
             return NoOpMonitor()

     def _before_train(self):
-        self._dir = logger.LOG_DIR
+        self._dir = logger.get_logger_dir()
         self._fname = os.path.join(self._dir, self.FILENAME)
         if os.path.isfile(self._fname):
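Since JSONWriter groups all scalars by global step in one file under the logger directory (``FILENAME`` is ``stats.json`` in tensorpack), reading the training history back afterwards is short; a sketch, assuming a run has already written stats:

    import json
    import os
    from tensorpack.utils import logger

    with open(os.path.join(logger.get_logger_dir(), 'stats.json')) as f:
        stats = json.load(f)      # a list of dicts, one per logged step
    print(stats[-1])              # scalars recorded at the last step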
...
@@ -184,7 +184,7 @@ class HumanHyperParamSetter(HyperParamSetter):
                 If the pair is not found, the param will not be changed.
         """
         super(HumanHyperParamSetter, self).__init__(param)
-        self.file_name = os.path.join(logger.LOG_DIR, file_name)
+        self.file_name = os.path.join(logger.get_logger_dir(), file_name)
         logger.info("Use {} to set hyperparam: '{}'.".format(
             self.file_name, self.param.readable_name))
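HumanHyperParamSetter watches a plain-text file of ``name:value`` lines in the logger directory, so a hyperparameter can be nudged during training by appending to that file; a sketch, assuming the callback was built with its default ``hyper.txt`` file name:

    import os
    from tensorpack.utils import logger

    # picked up the next time the callback triggers (each epoch by default)
    with open(os.path.join(logger.get_logger_dir(), 'hyper.txt'), 'a') as f:
        f.write('learning_rate:1e-5\n')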
...
@@ -105,7 +105,7 @@ class GPUUtilizationTracker(Callback):
 class GraphProfiler(Callback):
     """
     Enable profiling by installing session hooks,
-    and write metadata or tracing files to ``logger.LOG_DIR``.
+    and write metadata or tracing files to ``logger.get_logger_dir()``.

     The tracing files can be loaded from ``chrome://tracing``.
     The metadata files can be processed by
@@ -125,7 +125,7 @@ class GraphProfiler(Callback):
             dump_event(bool): Dump to an event processed by FileWriter and
                 will be shown in TensorBoard.
         """
-        self._dir = logger.LOG_DIR
+        self._dir = logger.get_logger_dir()
         self._dump_meta = bool(dump_metadata)
         self._dump_tracing = bool(dump_tracing)
         self._dump_event = bool(dump_event)
...
@@ -27,7 +27,7 @@ class ModelSaver(Callback):
         """
         Args:
             max_to_keep, keep_checkpoint_every_n_hours(int): the same as in ``tf.train.Saver``.
-            checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
+            checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
             var_collections (str or list of str): collection of the variables (or list of collections) to save.
         """
         self._max_to_keep = max_to_keep
@@ -43,7 +43,7 @@ class ModelSaver(Callback):
             var_collections = [var_collections]
         self.var_collections = var_collections
         if checkpoint_dir is None:
-            checkpoint_dir = logger.LOG_DIR
+            checkpoint_dir = logger.get_logger_dir()
         assert checkpoint_dir is not None
         if not tf.gfile.IsDirectory(checkpoint_dir):
             tf.gfile.MakeDirs(checkpoint_dir)
@@ -115,7 +115,7 @@ class MinSaver(Callback):
         Note:
             It assumes that :class:`ModelSaver` is used with
-            ``checkpoint_dir=logger.LOG_DIR`` (the default). And it will save
+            ``checkpoint_dir=logger.get_logger_dir()`` (the default). And it will save
             the model to that directory as well.
         """
         self.monitor_stat = monitor_stat
@@ -143,13 +143,13 @@ class MinSaver(Callback):
             self._save()

     def _save(self):
-        ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+        ckpt = tf.train.get_checkpoint_state(logger.get_logger_dir())
         if ckpt is None:
             raise RuntimeError(
                 "Cannot find a checkpoint state. Do you forget to use ModelSaver?")
         path = ckpt.model_checkpoint_path
-        newname = os.path.join(logger.LOG_DIR,
+        newname = os.path.join(logger.get_logger_dir(),
                                self.filename or
                                ('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
         files_to_copy = tf.gfile.Glob(path + '*')
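MinSaver does not write checkpoints itself; it copies whatever ModelSaver last saved whenever the monitored stat hits a new minimum, which is why the two appear together in a callback list. A hypothetical example (the stat name is illustrative):

    from tensorpack.callbacks import ModelSaver, MinSaver

    callbacks = [
        ModelSaver(),            # writes checkpoints to logger.get_logger_dir()
        MinSaver('val-error'),   # copies the best checkpoint as 'min-val-error'
    ]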
...
@@ -71,7 +71,7 @@ class InjectShell(Callback):
 class DumpParamAsImage(Callback):
     """
-    Dump a tensor to image(s) to ``logger.LOG_DIR`` after every epoch.
+    Dump a tensor to image(s) to ``logger.get_logger_dir()`` after every epoch.

     Note that it requires the tensor is directly evaluable, i.e. either inputs
     are not its dependency (e.g. the weights of the model), or the inputs are
@@ -93,7 +93,7 @@ class DumpParamAsImage(Callback):
             self.prefix = op_name
         else:
             self.prefix = prefix
-        self.log_dir = logger.LOG_DIR
+        self.log_dir = logger.get_logger_dir()
         self.scale = scale

     def _before_train(self):
...
@@ -270,14 +270,14 @@ def get_model_loader(filename):
 def TryResumeTraining():
     """
-    Try loading latest checkpoint from ``logger.LOG_DIR``, only if there is one.
+    Try loading latest checkpoint from ``logger.get_logger_dir()``, only if there is one.

     Returns:
         SessInit: either a :class:`JustCurrentSession`, or a :class:`SaverRestore`.
     """
-    if not logger.LOG_DIR:
+    if not logger.get_logger_dir():
         return JustCurrentSession()
-    path = os.path.join(logger.LOG_DIR, 'checkpoint')
+    path = os.path.join(logger.get_logger_dir(), 'checkpoint')
     if not tf.gfile.Exists(path):
         return JustCurrentSession()
     return SaverRestore(path)
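A usage sketch: passed as ``session_init``, this makes re-running the same script resume from the last checkpoint if one is present in the logger directory. The config fields and import paths below are assumptions for illustration (``model`` and ``df`` are presumed defined elsewhere):

    from tensorpack import TrainConfig
    from tensorpack.tfutils.sessinit import TryResumeTraining

    config = TrainConfig(
        model=model,                       # assumed: a ModelDesc defined elsewhere
        dataflow=df,                       # assumed: a DataFlow defined elsewhere
        session_init=TryResumeTraining(),  # SaverRestore if a checkpoint exists,
    )                                      # otherwise JustCurrentSession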
...
@@ -11,7 +11,7 @@ from datetime import datetime
 from six.moves import input
 import sys

-__all__ = ['set_logger_dir', 'auto_set_dir']
+__all__ = ['set_logger_dir', 'auto_set_dir', 'get_logger_dir']

 class _MyFormatter(logging.Formatter):
@@ -80,7 +80,8 @@ def set_logger_dir(dirname, action=None):
     Args:
         dirname(str): log directory
-        action(str): an action of ("k","b","d","n","q") to be performed. Will ask user by default.
+        action(str): an action of ("k","b","d","n","q") to be performed
+            when the directory exists. Will ask user by default.
     """
     global LOG_DIR, _FILE_HANDLER
     if _FILE_HANDLER:
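The clarified docstring makes the semantics explicit: ``action`` is only consulted when ``dirname`` already exists, and passing one up front skips the interactive prompt. For example (directory name illustrative; action letters per the prompt in this version of the code):

    from tensorpack.utils import logger

    # 'd' deletes the existing directory; 'b' backs it up, 'n' switches
    # to a new timestamped name, 'k' keeps it, 'q' aborts
    logger.set_logger_dir('train_log/myexp', action='d')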
@@ -128,3 +129,12 @@ def auto_set_dir(action=None, name=None):
     if name:
         auto_dirname += ':%s' % name
     set_logger_dir(auto_dirname, action=action)
+
+
+def get_logger_dir():
+    """
+    Returns:
+        The logger directory, or None if not set.
+        The directory is used for general logging, tensorboard events, checkpoints, etc.
+    """
+    return LOG_DIR
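A quick sketch of the new accessor's contract, per the docstring above (the directory name is illustrative; the first assert presumes a fresh process):

    from tensorpack.utils import logger

    assert logger.get_logger_dir() is None   # nothing set yet

    logger.set_logger_dir('train_log/demo')  # hypothetical directory
    assert logger.get_logger_dir() == 'train_log/demo'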