Commit 1a19939e authored by Yuxin Wu's avatar Yuxin Wu

Replace some fs operations by tf.gfile for other fs support. (fix #416)

parent c5e05d7a
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
import tensorflow as tf import tensorflow as tf
from datetime import datetime from datetime import datetime
import os import os
import shutil
import glob
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
...@@ -47,7 +45,8 @@ class ModelSaver(Callback): ...@@ -47,7 +45,8 @@ class ModelSaver(Callback):
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = logger.LOG_DIR checkpoint_dir = logger.LOG_DIR
assert checkpoint_dir is not None assert checkpoint_dir is not None
assert tf.gfile.IsDirectory(checkpoint_dir), checkpoint_dir if not tf.gfile.IsDirectory(checkpoint_dir):
tf.gfile.MakeDirs(checkpoint_dir)
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir
def _setup_graph(self): def _setup_graph(self):
...@@ -153,9 +152,9 @@ class MinSaver(Callback): ...@@ -153,9 +152,9 @@ class MinSaver(Callback):
newname = os.path.join(logger.LOG_DIR, newname = os.path.join(logger.LOG_DIR,
self.filename or self.filename or
('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat)) ('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
files_to_copy = glob.glob(path + '*') files_to_copy = tf.gfile.Glob(path + '*')
for file_to_copy in files_to_copy: for file_to_copy in files_to_copy:
shutil.copy(file_to_copy, file_to_copy.replace(path, newname)) tf.gfile.Copy(file_to_copy, file_to_copy.replace(path, newname), overwrite=True)
logger.info("Model with {} '{}' saved.".format( logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat)) 'maximum' if self.reverse else 'minimum', self.monitor_stat))
......
...@@ -258,10 +258,10 @@ def get_model_loader(filename): ...@@ -258,10 +258,10 @@ def get_model_loader(filename):
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert os.path.isfile(filename), filename assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
elif filename.endswith('.npz'): elif filename.endswith('.npz'):
assert os.path.isfile(filename), filename assert tf.gfile.Exists(filename), filename
obj = np.load(filename) obj = np.load(filename)
return DictRestore(dict(obj)) return DictRestore(dict(obj))
else: else:
...@@ -278,6 +278,6 @@ def TryResumeTraining(): ...@@ -278,6 +278,6 @@ def TryResumeTraining():
if not logger.LOG_DIR: if not logger.LOG_DIR:
return JustCurrentSession() return JustCurrentSession()
path = os.path.join(logger.LOG_DIR, 'checkpoint') path = os.path.join(logger.LOG_DIR, 'checkpoint')
if not os.path.isfile(path): if not tf.gfile.Exists(path):
return JustCurrentSession() return JustCurrentSession()
return SaverRestore(path) return SaverRestore(path)
...@@ -148,7 +148,7 @@ def get_checkpoint_path(model_path): ...@@ -148,7 +148,7 @@ def get_checkpoint_path(model_path):
if os.path.basename(model_path) == model_path: if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142 model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint': if os.path.basename(model_path) == 'checkpoint':
assert os.path.isfile(model_path), model_path assert tf.gfile.Exists(model_path), model_path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path)) model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2 # to be consistent with either v1 or v2
...@@ -162,7 +162,7 @@ def get_checkpoint_path(model_path): ...@@ -162,7 +162,7 @@ def get_checkpoint_path(model_path):
logger.warn( logger.warn(
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path)) "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path assert tf.gfile.Exists(model_path) or tf.gfile.Exists(model_path + '.index'), model_path
return model_path return model_path
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment