Commit 2ba9c3cd authored by Yuxin Wu

auto load epoch number from JSON. (#171)

parent 7782e724
...@@ -143,6 +143,8 @@ class JSONWriter(TrainingMonitor): ...@@ -143,6 +143,8 @@ class JSONWriter(TrainingMonitor):
""" """
Write all scalar data to a json, grouped by their global step. Write all scalar data to a json, grouped by their global step.
""" """
FILENAME = 'stat.json'
def __new__(cls): def __new__(cls):
if logger.LOG_DIR: if logger.LOG_DIR:
return super(JSONWriter, cls).__new__(cls) return super(JSONWriter, cls).__new__(cls)
...@@ -152,7 +154,7 @@ class JSONWriter(TrainingMonitor): ...@@ -152,7 +154,7 @@ class JSONWriter(TrainingMonitor):
def _setup_graph(self): def _setup_graph(self):
self._dir = logger.LOG_DIR self._dir = logger.LOG_DIR
self._fname = os.path.join(self._dir, 'stat.json') self._fname = os.path.join(self._dir, self.FILENAME)
if os.path.isfile(self._fname): if os.path.isfile(self._fname):
# TODO make a backup first? # TODO make a backup first?
...@@ -160,6 +162,14 @@ class JSONWriter(TrainingMonitor): ...@@ -160,6 +162,14 @@ class JSONWriter(TrainingMonitor):
with open(self._fname) as f: with open(self._fname) as f:
self._stats = json.load(f) self._stats = json.load(f)
assert isinstance(self._stats, list), type(self._stats) assert isinstance(self._stats, list), type(self._stats)
try:
epoch = self._stats[-1]['epoch_num'] + 1
except Exception:
pass
else:
logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
self.trainer.config.starting_epoch = epoch
else: else:
self._stats = [] self._stats = []
self._stat_now = {} self._stat_now = {}
......
...@@ -6,7 +6,8 @@ import tensorflow as tf ...@@ -6,7 +6,8 @@ import tensorflow as tf
from ..callbacks import ( from ..callbacks import (
Callbacks, MovingAverageSummary, Callbacks, MovingAverageSummary,
ProgressBar, MergeAllSummaries) ProgressBar, MergeAllSummaries,
TFSummaryWriter, JSONWriter, ScalarPrinter)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -16,7 +17,6 @@ from ..tfutils import (JustCurrentSession, ...@@ -16,7 +17,6 @@ from ..tfutils import (JustCurrentSession,
from ..tfutils.sesscreate import NewSessionCreator from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.optimizer import apply_grad_processors from ..tfutils.optimizer import apply_grad_processors
from .input_data import InputData from .input_data import InputData
from ..callbacks.monitor import TFSummaryWriter, JSONWriter, ScalarPrinter
__all__ = ['TrainConfig'] __all__ = ['TrainConfig']
...@@ -44,7 +44,7 @@ class TrainConfig(object): ...@@ -44,7 +44,7 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()]``. The list of ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), LoadEpochNum()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment