Commit aa7e18fc authored by Yuxin Wu's avatar Yuxin Wu

sessinit

parent acb441ca
......@@ -37,6 +37,10 @@ def _import_external_ops(message):
else:
from tensorflow.python.ops import gen_nccl_ops # noqa
return
if 'ZMQConnection' in message:
import zmq_ops
return
logger.error("Unhandled error: " + message)
def guess_inputs(input_dir):
......
......@@ -306,12 +306,14 @@ class JSONWriter(MonitorBase):
return NoOpMonitor("JSONWriter")
@staticmethod
def load_existing_json():
def load_existing_json(dir=None):
"""
Look for an existing json under :meth:`logger.get_logger_dir()` named "stats.json",
Look for an existing json under dir (defaults to
:meth:`logger.get_logger_dir()`) named "stats.json",
and return the loaded list of statistics if found. Returns None otherwise.
"""
dir = logger.get_logger_dir()
if dir is None:
dir = logger.get_logger_dir()
fname = os.path.join(dir, JSONWriter.FILENAME)
if tf.gfile.Exists(fname):
with open(fname) as f:
......@@ -321,12 +323,12 @@ class JSONWriter(MonitorBase):
return None
@staticmethod
def load_existing_epoch_number():
def load_existing_epoch_number(dir=None):
"""
Try to load the latest epoch number from an existing json stats file (if any).
Returns None if not found.
"""
stats = JSONWriter.load_existing_json()
stats = JSONWriter.load_existing_json(dir)
try:
return int(stats[-1]['epoch_num'])
except Exception:
......
......@@ -206,7 +206,7 @@ class AutoResumeTrainConfig(TrainConfig):
"""
found_sessinit = False
if always_resume or 'session_init' not in kwargs:
sessinit = self._get_sessinit_resume()
sessinit = self.get_sessinit_resume()
if sessinit is not None:
found_sessinit = True
path = sessinit.path
......@@ -219,7 +219,7 @@ class AutoResumeTrainConfig(TrainConfig):
found_last_epoch = False
if always_resume or 'starting_epoch' not in kwargs:
last_epoch = self._get_last_epoch()
last_epoch = JSONWriter.load_existing_epoch_number()
if last_epoch is not None:
found_last_epoch = True
now_epoch = last_epoch + 1
......@@ -231,14 +231,13 @@ class AutoResumeTrainConfig(TrainConfig):
super(AutoResumeTrainConfig, self).__init__(**kwargs)
def _get_sessinit_resume(self):
logdir = logger.get_logger_dir()
if not logdir:
@staticmethod
def get_sessinit_resume(dir=None):
if dir is None:
dir = logger.get_logger_dir()
if not dir:
return None
path = os.path.join(logdir, 'checkpoint')
path = os.path.join(dir, 'checkpoint')
if not tf.gfile.Exists(path):
return None
return SaverRestore(path)
def _get_last_epoch(self):
return JSONWriter.load_existing_epoch_number()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment