Commit 78762b38 authored by Dan Anghel's avatar Dan Anghel Committed by Yuxin Wu

Save checkpoints and event files on Google Cloud Storage (#1316)

* Fix to be able to save checkpoints and event files on Google Cloud Storage

* Created custom normpath() function to handle the case of remote Cloud storages
parent e68eec29
......@@ -17,7 +17,7 @@ import threading
from ..compat import tfv1 as tf
from ..libinfo import __git_version__
from ..tfutils.summary import create_image_summary, create_scalar_summary
from ..utils import logger
from ..utils import fs, logger
from ..utils.develop import HIDE_DOC
from .base import Callback
......@@ -239,7 +239,7 @@ class TFEventWriter(MonitorBase):
if logdir is None:
logdir = logger.get_logger_dir()
assert tf.gfile.IsDirectory(logdir), logdir
self._logdir = os.path.normpath(logdir)
self._logdir = fs.normpath(logdir)
self._max_queue = max_queue
self._flush_secs = flush_secs
self._split_files = split_files
......
......@@ -6,7 +6,7 @@ import os
from datetime import datetime
from ..compat import tfv1 as tf
from ..utils import logger
from ..utils import fs, logger
from .base import Callback
__all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
......@@ -45,7 +45,7 @@ class ModelSaver(Callback):
# If None, allow it to be init, but fail later if used
# For example, if chief_only=True, it can still be safely initialized
# in non-chief workers which don't have logger dir
self.checkpoint_dir = os.path.normpath(checkpoint_dir) if checkpoint_dir is not None else checkpoint_dir
self.checkpoint_dir = fs.normpath(checkpoint_dir) if checkpoint_dir is not None else checkpoint_dir
def _setup_graph(self):
assert self.checkpoint_dir is not None, \
......@@ -121,7 +121,7 @@ class MinSaver(Callback):
self.checkpoint_dir = checkpoint_dir
if self.checkpoint_dir is None:
self.checkpoint_dir = logger.get_logger_dir()
self.checkpoint_dir = os.path.normpath(self.checkpoint_dir)
self.checkpoint_dir = fs.normpath(self.checkpoint_dir)
def _get_stat(self):
try:
......
......@@ -10,7 +10,7 @@ from six.moves import urllib
from . import logger
from .utils import execute_only_once
__all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path']
__all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path', 'normpath']
def mkdir_p(dirname):
......@@ -106,5 +106,19 @@ def get_dataset_path(*args):
return os.path.join(d, *args)
def normpath(path):
"""
Normalizes a path to a folder by taking into consideration remote storages like Cloud storaged
referenced by '://' at the beginning of the path.
Args:
args: path to be normalized.
Returns:
str: normalized path.
"""
return path if '://' in path else os.path.normpath(path)
if __name__ == '__main__':
download('http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz', '.')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment