Commit 596d8008 authored by Yuxin Wu

use a callback to maintain global_step

parent bbb47815
@@ -7,7 +7,7 @@ import tensorflow as tf
 import numpy as np
 import time
 from tensorpack import (FeedfreeTrainerBase, TowerContext,
-                        get_global_step_var, QueueInput, ModelDesc)
+                        QueueInput, ModelDesc)
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient
 from tensorpack.dataflow import DataFlow
@@ -92,8 +92,7 @@ class GANTrainer(FeedfreeTrainerBase):
             self.model.d_loss, var_list=self.model.d_vars)
         grads = apply_grad_processors(
             grads, self.model.get_gradient_processor_d())
-        self.d_min = self.config.optimizer.apply_gradients(
-            grads, get_global_step_var(), name='d_op')
+        self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
         self.train_op = self.d_min
...
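Note on the change above: `tf.train.Optimizer.apply_gradients` only increments a step counter when one is passed as its second argument. A minimal, self-contained TF1 sketch (toy loss and a stand-in `global_step` variable, not tensorpack's own) of the before/after behavior:

```python
import tensorflow as tf

x = tf.get_variable('x', initializer=1.0)
loss = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)

global_step = tf.get_variable('global_step', initializer=0,
                              trainable=False, dtype=tf.int32)

# Before this commit: the optimizer increments global_step itself.
# train_op = opt.apply_gradients(grads, global_step=global_step)

# After: the train op no longer touches global_step; a separate
# assign_add op, fetched by a callback on every step, does the counting.
train_op = opt.apply_gradients(grads)
incr_op = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([train_op, incr_op])      # one training step
    print(sess.run(global_step))       # 1
```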
@@ -11,11 +11,15 @@ from six.moves import zip
 import tqdm

 from ..utils import logger, get_tqdm_kwargs
-from ..utils.naming import MOVING_SUMMARY_VARS_KEY
+from ..utils.naming import (
+    MOVING_SUMMARY_VARS_KEY,
+    GLOBAL_STEP_INCR_VAR_NAME,
+    LOCAL_STEP_OP_NAME)
 from ..tfutils.common import get_op_tensor_name, get_global_step_var
 from .base import Callback

-__all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
+__all__ = ['StepStatPrinter', 'MaintainStepCounter',
+           'SummaryMovingAverage', 'ProgressBar']


 class StepStatPrinter(Callback):
@@ -41,6 +45,24 @@ class StepStatPrinter(Callback):
         logger.info("{}: {}".format(n, v))


+class MaintainStepCounter(Callback):
+    """
+    Maintain the global step in the graph, and also create the local step tensor.
+    This callback is always enabled by the trainer; you don't need to
+    use it yourself.
+    """
+    def _setup_graph(self):
+        # ensure the global_step variable (and its incr op) exists
+        get_global_step_var()
+        self.gs_incr_var = self.trainer.sess.graph.get_tensor_by_name(
+            GLOBAL_STEP_INCR_VAR_NAME)
+        self.local_step = tf.mod(
+            self.gs_incr_var, self.trainer.config.step_per_epoch,
+            name=LOCAL_STEP_OP_NAME)
+
+    def _extra_fetches(self):
+        return [self.gs_incr_var.op]
+
+
 class SummaryMovingAverage(Callback):
     """ Maintain the moving average of the tensors
         in every step, and summarize them. Enabled by default.
...
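The contract used by `MaintainStepCounter` above: ops returned from `_extra_fetches()` are run by the trainer alongside every training step, so `global_step` advances exactly once per step, and the local step is simply the global step modulo `step_per_epoch`. A self-contained TF1 toy (the step count of 5 is made up) showing the `tf.mod` arithmetic:

```python
import tensorflow as tf

step_per_epoch = 5                      # stand-in for config.step_per_epoch
global_step = tf.get_variable('global_step', initializer=0,
                              trainable=False, dtype=tf.int32)
gs_incr = tf.assign_add(global_step, 1, name='global_step_incr')
local_step = tf.mod(gs_incr, step_per_epoch, name='local_step')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print([sess.run(local_step) for _ in range(7)])
    # [1, 2, 3, 4, 0, 1, 2]  (wraps at the epoch boundary)
```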
@@ -3,12 +3,14 @@
 # File: common.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME
 import tensorflow as tf
 from copy import copy
 import six
 from contextlib import contextmanager

+from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME
+from ..utils.argtools import memoized
+
 __all__ = ['get_default_sess_config',
            'get_global_step_value',
            'get_global_step_var',
@@ -43,6 +45,7 @@ def get_default_sess_config(mem_fraction=0.99):
     return conf


+@memoized
 def get_global_step_var():
     """
     Returns:
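`@memoized` makes repeated calls return the same variable instead of going through the creation path again. Assuming `memoized` behaves like a plain result cache (its implementation isn't shown in this diff), the effect is the same as `functools.lru_cache`:

```python
from functools import lru_cache

@lru_cache(maxsize=None)            # stand-in for tensorpack's @memoized
def get_global_step_var():
    print('creating global_step')   # side effect runs only once
    return object()                 # stand-in for the tf.Variable

a = get_global_step_var()           # prints 'creating global_step'
b = get_global_step_var()           # cached; no second creation
assert a is b
```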
@@ -54,11 +57,14 @@ def get_global_step_var():
     except KeyError:
         scope = tf.get_variable_scope()
         assert scope.name == '', \
-            "Creating global_step_var under a variable scope would cause problems!"
-        with tf.variable_scope(scope, reuse=False):
+            "The global_step variable should be created under the root variable scope!"
+        with tf.variable_scope(scope, reuse=False), \
+                tf.name_scope(None):
             var = tf.get_variable(GLOBAL_STEP_OP_NAME,
                                   initializer=0,
                                   trainable=False, dtype=tf.int32)
+        # also create the incr operation
+        tf.assign_add(var, 1, name=GLOBAL_STEP_INCR_OP_NAME)
         return var
...
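Why the added `tf.name_scope(None)`: `get_global_step_var()` may now be called from inside a callback or a tower, i.e. under some name scope, and passing `None` resets to the root scope so the created ops keep their bare names. A self-contained TF1 demonstration (the scope and op names are toy examples):

```python
import tensorflow as tf

with tf.name_scope('tower0'):
    a = tf.constant(1, name='step')        # op name: 'tower0/step'
    with tf.name_scope(None):              # escape to the root name scope
        b = tf.constant(1, name='step')    # op name: 'step'

print(a.op.name, b.op.name)                # tower0/step step
```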
@@ -13,7 +13,7 @@ from .config import TrainConfig
 from ..utils import logger
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
-from ..tfutils import get_global_step_var, get_global_step_value
+from ..tfutils import get_global_step_value
 from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_scalar_summary
@@ -144,7 +144,6 @@ class Trainer(object):
         """
         self._setup()
         describe_model()
-        get_global_step_var()   # ensure such var exists
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
...
@@ -6,7 +6,7 @@ import tensorflow as tf
 from ..callbacks import (
     Callbacks, SummaryMovingAverage,
-    StatPrinter, ProgressBar)
+    StatPrinter, ProgressBar, MaintainStepCounter)
 from ..dataflow.base import DataFlow
 from ..models import ModelDesc
 from ..utils import logger
@@ -42,6 +42,8 @@ class TrainConfig(object):
             is only used to provide the defaults. The defaults are
             ``[SummaryMovingAverage(), ProgressBar(), StatPrinter()]``. The list of
             callbacks that will be used in the end is ``callbacks + extra_callbacks``.
+            Note that ``StatPrinter`` should come last, so that it can print
+            stats generated by the other callbacks.
         session_config (tf.ConfigProto): the config used to instantiate the session.
         session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
         starting_epoch (int): The index of the first epoch.
@@ -83,7 +85,7 @@ class TrainConfig(object):
         assert_type(callbacks, list)
         if extra_callbacks is None:
             extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]
-        self.callbacks = callbacks + extra_callbacks
+        self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
         assert_type(self.callbacks, list)
         self.callbacks = Callbacks(self.callbacks)
...
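The net effect of the constructor change, as a runnable toy (the user callback name is made up, and classes are swapped for strings so this stands alone): `MaintainStepCounter` always comes first, so the step counter exists before any other callback sets up, while `StatPrinter` stays last as the docstring now requires.

```python
callbacks = ['MyUserCallback']            # hypothetical user callbacks
extra_callbacks = ['SummaryMovingAverage', 'ProgressBar', 'StatPrinter']

all_callbacks = ['MaintainStepCounter'] + callbacks + extra_callbacks
print(all_callbacks)
# ['MaintainStepCounter', 'MyUserCallback',
#  'SummaryMovingAverage', 'ProgressBar', 'StatPrinter']
```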
@@ -6,7 +6,6 @@
 import tensorflow as tf

 from ..utils import logger
-from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import QueueInput, FeedfreeInput
@@ -101,8 +100,7 @@ class SimpleFeedfreeTrainer(
         cost, grads = self._get_cost_and_grad()
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())
-        self.train_op = self.config.optimizer.apply_gradients(
-            grads, get_global_step_var(), name='min_op')
+        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
         # skip training
         # self.train_op = tf.group(*self.dequed_inputs)
...
@@ -11,8 +11,7 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
-from ..tfutils import (backup_collection, restore_collection,
-                       get_global_step_var, TowerContext)
+from ..tfutils import (backup_collection, restore_collection, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
 from .base import Trainer
@@ -112,8 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())
-        self.train_op = self.config.optimizer.apply_gradients(
-            grads, get_global_step_var(), name='min_op')
+        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')


 class AsyncMultiGPUTrainer(MultiGPUTrainer,
@@ -163,8 +161,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
         # use grad from the first tower for iteration in main thread
-        self.train_op = self.config.optimizer.apply_gradients(
-            grad_list[0], get_global_step_var(), name='min_op')
+        self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
         self._start_async_threads(grad_list)
...
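For context on `SyncMultiGPUTrainer._average_grads` (its body isn't part of this diff): averaging per-tower gradients typically zips the per-tower `(grad, var)` lists by variable and averages the grads. A hedged sketch of that pattern, not the library's actual code:

```python
import tensorflow as tf

def average_grads(grad_list):
    """grad_list: one [(grad, var), ...] list per tower, in the same var order."""
    averaged = []
    for pairs in zip(*grad_list):        # group the same variable across towers
        grads = [g for g, _ in pairs]
        var = pairs[0][1]                # the variable is shared by all towers
        averaged.append((tf.add_n(grads) / float(len(grads)), var))
    return averaged
```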
@@ -7,8 +7,9 @@ import tensorflow as tf
 from .base import Trainer
 from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
-from ..tfutils import (get_tensors_by_names, freeze_collection,
-                       get_global_step_var, TowerContext)
+from ..tfutils import (get_tensors_by_names,
+                       freeze_collection,
+                       TowerContext)
 from ..predict import OnlinePredictor, build_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import FeedInput
@@ -88,8 +89,7 @@ class SimpleTrainer(Trainer):
         grads = apply_grad_processors(grads,
                                       self.model.get_gradient_processor())
-        self.train_op = self.config.optimizer.apply_gradients(
-            grads, get_global_step_var(), name='min_op')
+        self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')

     def _trigger_epoch(self):
         if self.summary_op is not None:
...
@@ -7,6 +7,12 @@ import tensorflow as tf
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'

+GLOBAL_STEP_INCR_OP_NAME = 'global_step_incr'
+GLOBAL_STEP_INCR_VAR_NAME = 'global_step_incr:0'
+
+LOCAL_STEP_OP_NAME = 'local_step'
+LOCAL_STEP_VAR_NAME = 'local_step:0'
+
 # prefix of predict tower
 PREDICT_TOWER = 'towerp'
...
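A note on the new name pairs: TensorFlow names ops without a suffix and their output tensors with `:0`, which is why each `*_OP_NAME` constant has a matching `*_VAR_NAME` ending in `:0`, and why the callback above can resolve the tensor with `get_tensor_by_name(GLOBAL_STEP_INCR_VAR_NAME)`. A self-contained TF1 check:

```python
import tensorflow as tf

v = tf.get_variable('global_step', initializer=0,
                    trainable=False, dtype=tf.int32)
incr = tf.assign_add(v, 1, name='global_step_incr')

print(incr.op.name)   # 'global_step_incr'     (the op)
print(incr.name)      # 'global_step_incr:0'   (its first output tensor)

g = tf.get_default_graph()
assert g.get_tensor_by_name('global_step_incr:0') is incr
```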