Commit 1a19939e authored by Yuxin Wu's avatar Yuxin Wu

Replace some fs operations by tf.gfile for other fs support. (fix #416)

parent c5e05d7a
......@@ -5,8 +5,6 @@
import tensorflow as tf
from datetime import datetime
import os
import shutil
import glob
from .base import Callback
from ..utils import logger
......@@ -47,7 +45,8 @@ class ModelSaver(Callback):
if checkpoint_dir is None:
checkpoint_dir = logger.LOG_DIR
assert checkpoint_dir is not None
assert tf.gfile.IsDirectory(checkpoint_dir), checkpoint_dir
if not tf.gfile.IsDirectory(checkpoint_dir):
tf.gfile.MakeDirs(checkpoint_dir)
self.checkpoint_dir = checkpoint_dir
def _setup_graph(self):
......@@ -153,9 +152,9 @@ class MinSaver(Callback):
newname = os.path.join(logger.LOG_DIR,
self.filename or
('max-' + self.monitor_stat if self.reverse else 'min-' + self.monitor_stat))
files_to_copy = glob.glob(path + '*')
files_to_copy = tf.gfile.Glob(path + '*')
for file_to_copy in files_to_copy:
shutil.copy(file_to_copy, file_to_copy.replace(path, newname))
tf.gfile.Copy(file_to_copy, file_to_copy.replace(path, newname), overwrite=True)
logger.info("Model with {} '{}' saved.".format(
'maximum' if self.reverse else 'minimum', self.monitor_stat))
......
......@@ -258,10 +258,10 @@ def get_model_loader(filename):
:class:`SaverRestore` (otherwise).
"""
if filename.endswith('.npy'):
assert os.path.isfile(filename), filename
assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item())
elif filename.endswith('.npz'):
assert os.path.isfile(filename), filename
assert tf.gfile.Exists(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj))
else:
......@@ -278,6 +278,6 @@ def TryResumeTraining():
if not logger.LOG_DIR:
return JustCurrentSession()
path = os.path.join(logger.LOG_DIR, 'checkpoint')
if not os.path.isfile(path):
if not tf.gfile.Exists(path):
return JustCurrentSession()
return SaverRestore(path)
......@@ -148,7 +148,7 @@ def get_checkpoint_path(model_path):
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
assert os.path.isfile(model_path), model_path
assert tf.gfile.Exists(model_path), model_path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
# to be consistent with either v1 or v2
......@@ -162,7 +162,7 @@ def get_checkpoint_path(model_path):
logger.warn(
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path
assert os.path.isfile(model_path) or os.path.isfile(model_path + '.index'), model_path
assert tf.gfile.Exists(model_path) or tf.gfile.Exists(model_path + '.index'), model_path
return model_path
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment