Commit c24697ee authored by Yuxin Wu's avatar Yuxin Wu

try resume training..

parent e5f9f83a
......@@ -15,7 +15,7 @@ from .varmanip import (SessionUpdate, get_savename_from_varname,
__all__ = ['SessionInit', 'SaverRestore', 'SaverRestoreRelaxed',
'ParamRestore', 'DictRestore', 'ChainInit',
'JustCurrentSession', 'get_model_loader']
'JustCurrentSession', 'get_model_loader', 'TryResumeTraining']
class SessionInit(object):
......@@ -235,3 +235,18 @@ def get_model_loader(filename):
return DictRestore(np.load(filename, encoding='latin1').item())
else:
return SaverRestore(filename)
def TryResumeTraining():
"""
Load latest checkpoint from LOG_DIR, if there is one.
Returns:
SessInit: either a :class:`JustCurrentSession`, or a :class:`SaverRestore`.
"""
if not logger.LOG_DIR:
return JustCurrentSession()
path = os.path.join(logger.LOG_DIR, 'checkpoint')
if not os.path.isfile(path):
return JustCurrentSession()
return SaverRestore(path)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment