Commit c24697ee authored by Yuxin Wu's avatar Yuxin Wu

try resume training..

parent e5f9f83a
...@@ -15,7 +15,7 @@ from .varmanip import (SessionUpdate, get_savename_from_varname, ...@@ -15,7 +15,7 @@ from .varmanip import (SessionUpdate, get_savename_from_varname,
__all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed', __all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'DictRestore', 'ChainInit', 'ParamRestore', 'DictRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader'] 'JustCurrentSession', 'get_model_loader', 'TryResumeTraining']
class SessionInit(object): class SessionInit(object):
...@@ -235,3 +235,18 @@ def get_model_loader(filename): ...@@ -235,3 +235,18 @@ def get_model_loader(filename):
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
else: else:
return SaverRestore(filename) return SaverRestore(filename)
def TryResumeTraining():
"""
Load latest checkpoint from LOG_DIR, if there is one.
Returns:
SessInit: either a :class:`JustCurrentSession`, or a :class:`SaverRestore`.
"""
if not logger.LOG_DIR:
return JustCurrentSession()
path = os.path.join(logger.LOG_DIR, 'checkpoint')
if not os.path.isfile(path):
return JustCurrentSession()
return SaverRestore(path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment