Commit aa7e18fc authored by Yuxin Wu's avatar Yuxin Wu

sessinit

parent acb441ca
...@@ -37,6 +37,10 @@ def _import_external_ops(message): ...@@ -37,6 +37,10 @@ def _import_external_ops(message):
else: else:
from tensorflow.python.ops import gen_nccl_ops # noqa from tensorflow.python.ops import gen_nccl_ops # noqa
return return
if 'ZMQConnection' in message:
import zmq_ops
return
logger.error("Unhandled error: " + message)
def guess_inputs(input_dir): def guess_inputs(input_dir):
......
...@@ -306,12 +306,14 @@ class JSONWriter(MonitorBase): ...@@ -306,12 +306,14 @@ class JSONWriter(MonitorBase):
return NoOpMonitor("JSONWriter") return NoOpMonitor("JSONWriter")
@staticmethod @staticmethod
def load_existing_json(): def load_existing_json(dir=None):
""" """
Look for an existing json under :meth:`logger.get_logger_dir()` named "stats.json", Look for an existing json under dir (defaults to
:meth:`logger.get_logger_dir()`) named "stats.json",
and return the loaded list of statistics if found. Returns None otherwise. and return the loaded list of statistics if found. Returns None otherwise.
""" """
dir = logger.get_logger_dir() if dir is None:
dir = logger.get_logger_dir()
fname = os.path.join(dir, JSONWriter.FILENAME) fname = os.path.join(dir, JSONWriter.FILENAME)
if tf.gfile.Exists(fname): if tf.gfile.Exists(fname):
with open(fname) as f: with open(fname) as f:
...@@ -321,12 +323,12 @@ class JSONWriter(MonitorBase): ...@@ -321,12 +323,12 @@ class JSONWriter(MonitorBase):
return None return None
@staticmethod @staticmethod
def load_existing_epoch_number(): def load_existing_epoch_number(dir=None):
""" """
Try to load the latest epoch number from an existing json stats file (if any). Try to load the latest epoch number from an existing json stats file (if any).
Returns None if not found. Returns None if not found.
""" """
stats = JSONWriter.load_existing_json() stats = JSONWriter.load_existing_json(dir)
try: try:
return int(stats[-1]['epoch_num']) return int(stats[-1]['epoch_num'])
except Exception: except Exception:
......
...@@ -206,7 +206,7 @@ class AutoResumeTrainConfig(TrainConfig): ...@@ -206,7 +206,7 @@ class AutoResumeTrainConfig(TrainConfig):
""" """
found_sessinit = False found_sessinit = False
if always_resume or 'session_init' not in kwargs: if always_resume or 'session_init' not in kwargs:
sessinit = self._get_sessinit_resume() sessinit = self.get_sessinit_resume()
if sessinit is not None: if sessinit is not None:
found_sessinit = True found_sessinit = True
path = sessinit.path path = sessinit.path
...@@ -219,7 +219,7 @@ class AutoResumeTrainConfig(TrainConfig): ...@@ -219,7 +219,7 @@ class AutoResumeTrainConfig(TrainConfig):
found_last_epoch = False found_last_epoch = False
if always_resume or 'starting_epoch' not in kwargs: if always_resume or 'starting_epoch' not in kwargs:
last_epoch = self._get_last_epoch() last_epoch = JSONWriter.load_existing_epoch_number()
if last_epoch is not None: if last_epoch is not None:
found_last_epoch = True found_last_epoch = True
now_epoch = last_epoch + 1 now_epoch = last_epoch + 1
...@@ -231,14 +231,13 @@ class AutoResumeTrainConfig(TrainConfig): ...@@ -231,14 +231,13 @@ class AutoResumeTrainConfig(TrainConfig):
super(AutoResumeTrainConfig, self).__init__(**kwargs) super(AutoResumeTrainConfig, self).__init__(**kwargs)
def _get_sessinit_resume(self): @staticmethod
logdir = logger.get_logger_dir() def get_sessinit_resume(dir=None):
if not logdir: if dir is None:
dir = logger.get_logger_dir()
if not dir:
return None return None
path = os.path.join(logdir, 'checkpoint') path = os.path.join(dir, 'checkpoint')
if not tf.gfile.Exists(path): if not tf.gfile.Exists(path):
return None return None
return SaverRestore(path) return SaverRestore(path)
def _get_last_epoch(self):
return JSONWriter.load_existing_epoch_number()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment