Commit 7abf4ace authored by Yuxin Wu's avatar Yuxin Wu

Add AutoResumeTrainConfig to really do auto resuming from log dir (fix #660)

parent f130c10f
......@@ -8,6 +8,8 @@ so you won't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2018/03/12] `JSONWriter` used a different file name, and will not automatically restore epoch number.
`AutoResumeTrainConfig` was added to support resuming better.
+ [2017/10/21]
tensorpack is gradually switching to a new Trainer API.
The old API will keep working for a while. See [issue](https://github.com/ppwwyyxx/tensorpack/issues/458)
......
......@@ -7,6 +7,7 @@ import os
import numpy as np
import shutil
import time
from datetime import datetime
import operator
from collections import defaultdict
import six
......@@ -213,7 +214,7 @@ class TFEventWriter(TrainingMonitor):
"""
if logdir is None:
logdir = logger.get_logger_dir()
assert os.path.isdir(logdir), logdir
assert tf.gfile.IsDirectory(logdir), logdir
self._logdir = logdir
self._max_queue = max_queue
self._flush_secs = flush_secs
......@@ -249,13 +250,12 @@ class TFEventWriter(TrainingMonitor):
class JSONWriter(TrainingMonitor):
"""
Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
This monitor also attempts to recover the epoch number during setup,
if an existing json file is found at the same place.
If found an earlier json history file, will append to it.
"""
FILENAME = 'stat.json'
FILENAME = 'stats.json'
"""
The name of the json file.
The name of the json file. Do not change it.
"""
def __new__(cls):
......@@ -265,26 +265,61 @@ class JSONWriter(TrainingMonitor):
logger.warn("logger directory was not set. Ignore JSONWriter.")
return NoOpMonitor()
def _before_train(self):
self._dir = logger.get_logger_dir()
self._fname = os.path.join(self._dir, self.FILENAME)
@staticmethod
def load_existing_json():
    """
    Look for an existing json under :meth:`logger.get_logger_dir()` named "stats.json",
    and return the loaded list of statistics if found. Returns None otherwise.

    Returns:
        list or None: the parsed list of per-epoch stat dicts, or None if
            the file does not exist.
    """
    # NOTE: renamed from `dir` to avoid shadowing the `dir` builtin.
    logdir = logger.get_logger_dir()
    fname = os.path.join(logdir, JSONWriter.FILENAME)
    # tf.gfile is used (instead of os.path) so remote filesystems also work.
    if tf.gfile.Exists(fname):
        with open(fname) as f:
            stats = json.load(f)
            assert isinstance(stats, list), type(stats)
        return stats
    return None
if os.path.isfile(self._fname):
logger.info("Found JSON at {}, will append to it.".format(self._fname))
with open(self._fname) as f:
self._stats = json.load(f)
assert isinstance(self._stats, list), type(self._stats)
@staticmethod
def load_existing_epoch_number():
    """
    Try to load the latest epoch number from an existing json stats file (if any).
    Returns None if not found.
    """
    history = JSONWriter.load_existing_json()
    try:
        last_entry = history[-1]
        return int(last_entry['epoch_num'])
    except Exception:
        # Covers: no json at all (history is None), empty history,
        # or a last entry without an 'epoch_num' field.
        return None
def _before_train(self):
    """
    Load any existing json history and decide how to continue it.

    If the history's last epoch is the predecessor of the loop's
    ``starting_epoch`` (or the history has no usable epoch number), new
    stats are appended to it. Otherwise the old file is backed up with a
    timestamp suffix and a fresh history is started, so the file never
    contains inconsistent epoch numbers.
    """
    # NOTE: the original span contained stale removed-diff lines (duplicate
    # `epoch = ...` assignments and a superseded "rename first epoch" branch);
    # this is the reconstructed final version.
    stats = JSONWriter.load_existing_json()
    self._fname = os.path.join(logger.get_logger_dir(), JSONWriter.FILENAME)
    if stats is not None:
        try:
            epoch = stats[-1]['epoch_num'] + 1
        except Exception:
            epoch = None

        # NOTE(review): this check runs in before_train (not setup) —
        # presumably because trainer.loop is only configured by then; confirm.
        starting_epoch = self.trainer.loop.starting_epoch
        if epoch is None or epoch == starting_epoch:
            logger.info("Found existing JSON inside {}, will append to it.".format(logger.get_logger_dir()))
            self._stats = stats
        else:
            logger.warn(
                "History epoch value {} from JSON is not the predecessor of the starting_epoch value {}".format(
                    epoch - 1, starting_epoch))
            logger.warn("If you want to resume old training, either use `AutoResumeTrainConfig` "
                        "or correctly set the starting_epoch yourself to avoid inconsistency. "
                        "Epoch number will not be automatically loaded by JSONWriter.")

            # Do not append to an inconsistent history; back it up instead.
            backup_fname = JSONWriter.FILENAME + '.' + datetime.now().strftime('%m%d-%H%M%S')
            backup_fname = os.path.join(logger.get_logger_dir(), backup_fname)
            logger.warn("Now, we will start training at epoch {} and backup old json to {}".format(
                self.trainer.loop.starting_epoch, backup_fname))
            shutil.move(self._fname, backup_fname)
            self._stats = []
    else:
        self._stats = []
    self._stat_now = {}
......
......@@ -105,7 +105,7 @@ class SaverRestore(SessionInit):
logger.warn("SaverRestore expect a TF checkpoint, but got a model path '{}'.".format(model_path) +
" To load from a dict, use 'DictRestore'.")
model_path = get_checkpoint_path(model_path)
self.path = model_path
self.path = model_path # attribute used by AutoResumeTrainConfig!
self.prefix = prefix
self.ignore = [i if i.endswith(':0') else i + ':0' for i in ignore]
......@@ -262,7 +262,7 @@ def get_model_loader(filename):
return SaverRestore(filename)
@deprecated("Write the logic yourself!", "2018-06-01")
@deprecated("Write the logic yourself or use AutoResumeTrainConfig!", "2018-06-01")
def TryResumeTraining():
"""
Try loading latest checkpoint from ``logger.get_logger_dir()``, only if there is one.
......
......@@ -16,7 +16,7 @@ from ..utils.argtools import call_only_once
from ..tfutils import get_global_step_value
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
from ..callbacks.steps import MaintainStepCounter
......@@ -46,9 +46,10 @@ class TrainLoop(object):
"""
Configure the loop given the settings.
"""
self.starting_epoch = starting_epoch
self.max_epoch = max_epoch
self.steps_per_epoch = steps_per_epoch
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
self.steps_per_epoch = int(steps_per_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0
self._epoch_num = starting_epoch - 1
......@@ -215,6 +216,8 @@ class Trainer(object):
session_creator (tf.train.SessionCreator):
session_init (sessinit.SessionInit):
"""
assert isinstance(session_creator, tf.train.SessionCreator), session_creator
assert isinstance(session_init, SessionInit), session_init
session_init._setup_graph()
logger.info("Creating the session ...")
......
......@@ -2,6 +2,9 @@
# -*- coding: utf-8 -*-
# File: config.py
import os
import tensorflow as tf
from ..callbacks import (
MovingAverageSummary,
ProgressBar, MergeAllSummaries,
......@@ -9,11 +12,11 @@ from ..callbacks import (
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils import (JustCurrentSession, SessionInit)
from ..tfutils.sessinit import JustCurrentSession, SessionInit, SaverRestore
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
__all__ = ['TrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']
__all__ = ['TrainConfig', 'AutoResumeTrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']
def DEFAULT_CALLBACKS():
......@@ -145,7 +148,6 @@ class TrainConfig(object):
self.starting_epoch = int(starting_epoch)
self.max_epoch = int(max_epoch)
assert self.steps_per_epoch > 0 and self.max_epoch > 0
# Tower stuff are for Trainer v1 only:
nr_tower = max(nr_tower, 1)
......@@ -167,3 +169,69 @@ class TrainConfig(object):
@property
def callbacks(self): # disable setter
return self._callbacks
class AutoResumeTrainConfig(TrainConfig):
    """
    Same as :class:`TrainConfig`, but does the following to automatically
    resume from training:

    1. If a checkpoint was found in :meth:`logger.get_logger_dir()`, set
       `session_init` option to load it.
    2. If a JSON history was found in :meth:`logger.get_logger_dir()`, try to
       load the epoch number from it and set the `starting_epoch` option to
       continue training.

    You can choose to let these two options either overwrite or not
    overwrite user-provided arguments, as explained below.
    """
    def __init__(self, always_resume=True, **kwargs):
        """
        Args:
            always_resume (bool): If False, user-provided arguments
                `session_init` and `starting_epoch` will take priority.
                Otherwise, resume will take priority.
            kwargs: same as in :class:`TrainConfig`.

        Notes:
            The main goal of this class is to let a training job resume
            without changing any line of code or command line arguments.
            So it's useful to let resume take priority over user-provided arguments sometimes:

            If your training starts from a pretrained model,
            you would want it to use user-provided model loader at the
            beginning, but a "resume" model loader when the job was
            interrupted and restarted.
        """
        # Checkpoint resume: may overwrite a user-provided session_init.
        if always_resume or 'session_init' not in kwargs:
            found_sessinit = self._get_sessinit_resume()
            if found_sessinit is not None:
                ckpt_path = found_sessinit.path
                if 'session_init' in kwargs:
                    logger.info("Found checkpoint at {}. "
                                "session_init arguments will be overwritten.".format(ckpt_path))
                else:
                    logger.info("Will load checkpoint at {}.".format(ckpt_path))
                kwargs['session_init'] = found_sessinit

        # Epoch resume: may overwrite a user-provided starting_epoch.
        if always_resume or 'starting_epoch' not in kwargs:
            last_epoch = self._get_last_epoch()
            if last_epoch is not None:
                now_epoch = last_epoch + 1
                logger.info("Found history statistics from JSON. "
                            "Overwrite the starting epoch to epoch #{}.".format(now_epoch))
                kwargs['starting_epoch'] = now_epoch

        super(AutoResumeTrainConfig, self).__init__(**kwargs)

    def _get_sessinit_resume(self):
        # Resume only when a logger dir is set and it contains a TF
        # "checkpoint" index file.
        logdir = logger.get_logger_dir()
        if not logdir:
            return None
        path = os.path.join(logdir, 'checkpoint')
        if not tf.gfile.Exists(path):
            return None
        return SaverRestore(path)

    def _get_last_epoch(self):
        # Delegates to the monitor that wrote the history in the first place.
        return JSONWriter.load_existing_epoch_number()
......@@ -65,9 +65,9 @@ def launch_train_with_config(config, trainer):
.. code-block:: python
# with the old trainer:
# With the old trainer:
SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train()
# with the new trainer:
# With the current version of trainer:
launch_train_with_config(
config, SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu'))
"""
......
......@@ -49,6 +49,7 @@ class SimpleTrainer(SingleCostTrainer):
Single-GPU single-cost single-tower trainer.
"""
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
logger.info("Building graph for a single training tower ...")
with TowerContext('', is_training=True):
grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
opt = get_opt_fn()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment