Commit 78762b38 authored by Dan Anghel's avatar Dan Anghel Committed by Yuxin Wu

Save checkpoints and event files on Google Cloud Storage (#1316)

* Fix to be able to save checkpoints and event files on Google Cloud Storage

* Created custom normpath() function to handle the case of remote Cloud storages
parent e68eec29
...@@ -17,7 +17,7 @@ import threading ...@@ -17,7 +17,7 @@ import threading
from ..compat import tfv1 as tf from ..compat import tfv1 as tf
from ..libinfo import __git_version__ from ..libinfo import __git_version__
from ..tfutils.summary import create_image_summary, create_scalar_summary from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger from ..utils import fs, logger
from ..utils.develop import HIDE_DOC from ..utils.develop import HIDE_DOC
from .base import Callback from .base import Callback
...@@ -239,7 +239,7 @@ class TFEventWriter(MonitorBase): ...@@ -239,7 +239,7 @@ class TFEventWriter(MonitorBase):
if logdir is None: if logdir is None:
logdir = logger.get_logger_dir() logdir = logger.get_logger_dir()
assert tf.gfile.IsDirectory(logdir), logdir assert tf.gfile.IsDirectory(logdir), logdir
self._logdir = os.path.normpath(logdir) self._logdir = fs.normpath(logdir)
self._max_queue = max_queue self._max_queue = max_queue
self._flush_secs = flush_secs self._flush_secs = flush_secs
self._split_files = split_files self._split_files = split_files
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from datetime import datetime from datetime import datetime
from ..compat import tfv1 as tf from ..compat import tfv1 as tf
from ..utils import logger from ..utils import fs, logger
from .base import Callback from .base import Callback
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver'] __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
...@@ -45,7 +45,7 @@ class ModelSaver(Callback): ...@@ -45,7 +45,7 @@ class ModelSaver(Callback):
# If None, allow it to be init, but fail later if used # If None, allow it to be init, but fail later if used
# For example, if chief_only=True, it can still be safely initialized # For example, if chief_only=True, it can still be safely initialized
# in non-chief workers which don't have logger dir # in non-chief workers which don't have logger dir
self.checkpoint_dir = os.path.normpath(checkpoint_dir) if checkpoint_dir is not None else checkpoint_dir self.checkpoint_dir = fs.normpath(checkpoint_dir) if checkpoint_dir is not None else checkpoint_dir
def _setup_graph(self): def _setup_graph(self):
assert self.checkpoint_dir is not None, \ assert self.checkpoint_dir is not None, \
...@@ -121,7 +121,7 @@ class MinSaver(Callback): ...@@ -121,7 +121,7 @@ class MinSaver(Callback):
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir
if self.checkpoint_dir is None: if self.checkpoint_dir is None:
self.checkpoint_dir = logger.get_logger_dir() self.checkpoint_dir = logger.get_logger_dir()
self.checkpoint_dir = os.path.normpath(self.checkpoint_dir) self.checkpoint_dir = fs.normpath(self.checkpoint_dir)
def _get_stat(self): def _get_stat(self):
try: try:
......
...@@ -10,7 +10,7 @@ from six.moves import urllib ...@@ -10,7 +10,7 @@ from six.moves import urllib
from . import logger from . import logger
from .utils import execute_only_once from .utils import execute_only_once
__all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path'] __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path', 'normpath']
def mkdir_p(dirname): def mkdir_p(dirname):
...@@ -106,5 +106,19 @@ def get_dataset_path(*args): ...@@ -106,5 +106,19 @@ def get_dataset_path(*args):
return os.path.join(d, *args) return os.path.join(d, *args)
def normpath(path):
"""
Normalizes a path to a folder by taking into consideration remote storages like Cloud storaged
referenced by '://' at the beginning of the path.
Args:
args: path to be normalized.
Returns:
str: normalized path.
"""
return path if '://' in path else os.path.normpath(path)
if __name__ == '__main__': if __name__ == '__main__':
download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.') download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment