Commit f26b9b59 authored by Yuxin Wu's avatar Yuxin Wu

add RunUpdateOps to the default callbacks

parent f73717ab
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
from ..callbacks import ( from ..callbacks import (
Callbacks, MovingAverageSummary, Callbacks, MovingAverageSummary,
ProgressBar, MergeAllSummaries, ProgressBar, MergeAllSummaries,
TFSummaryWriter, JSONWriter, ScalarPrinter) TFSummaryWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow from ..dataflow.base import DataFlow
from ..models import ModelDesc from ..models import ModelDesc
from ..utils import logger from ..utils import logger
...@@ -44,7 +44,7 @@ class TrainConfig(object): ...@@ -44,7 +44,7 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training. callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), LoadEpochNum()]``. The list of ``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries(), RunUpdateOps()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``. callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`. monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``. Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
...@@ -95,7 +95,8 @@ class TrainConfig(object): ...@@ -95,7 +95,8 @@ class TrainConfig(object):
extra_callbacks = [ extra_callbacks = [
MovingAverageSummary(), MovingAverageSummary(),
ProgressBar(), ProgressBar(),
MergeAllSummaries()] MergeAllSummaries(),
RunUpdateOps()]
self._callbacks = callbacks + extra_callbacks self._callbacks = callbacks + extra_callbacks
assert_type(self._callbacks, list) assert_type(self._callbacks, list)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment