Commit 596d8008 authored by Yuxin Wu

use a callback to maintain global_step

parent bbb47815
......@@ -7,7 +7,7 @@ import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, TowerContext,
- get_global_step_var, QueueInput, ModelDesc)
+ QueueInput, ModelDesc)
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.tfutils.gradproc import apply_grad_processors, CheckGradient
from tensorpack.dataflow import DataFlow
......@@ -92,8 +92,7 @@ class GANTrainer(FeedfreeTrainerBase):
self.model.d_loss, var_list=self.model.d_vars)
grads = apply_grad_processors(
grads, self.model.get_gradient_processor_d())
- self.d_min = self.config.optimizer.apply_gradients(
- grads, get_global_step_var(), name='d_op')
+ self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
self.train_op = self.d_min
......
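For context: in TF 1.x, ``Optimizer.apply_gradients`` takes an optional ``global_step`` argument that is incremented together with the weight update. This commit drops that argument in every trainer, because the increment is now owned by the new ``MaintainStepCounter`` callback. A minimal sketch of the two call styles, with an illustrative loss:

    import tensorflow as tf

    x = tf.get_variable('x', initializer=1.0)
    loss = tf.square(x)
    opt = tf.train.GradientDescentOptimizer(0.1)
    grads = opt.compute_gradients(loss)      # list of (gradient, variable) pairs

    # before this commit: the train op itself incremented global_step
    gs = tf.train.get_or_create_global_step()
    train_op_old = opt.apply_gradients(grads, global_step=gs, name='min_op_old')

    # after this commit: apply_gradients only updates the weights;
    # a separate assign_add op, fetched by the callback, advances the counter
    train_op_new = opt.apply_gradients(grads, name='min_op_new')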
......@@ -11,11 +11,15 @@ from six.moves import zip
import tqdm
from ..utils import logger, get_tqdm_kwargs
- from ..utils.naming import MOVING_SUMMARY_VARS_KEY
+ from ..utils.naming import (
+ MOVING_SUMMARY_VARS_KEY,
+ GLOBAL_STEP_INCR_VAR_NAME,
+ LOCAL_STEP_OP_NAME)
from ..tfutils.common import get_op_tensor_name, get_global_step_var
from .base import Callback
- __all__ = ['StepStatPrinter', 'SummaryMovingAverage', 'ProgressBar']
+ __all__ = ['StepStatPrinter', 'MaintainStepCounter',
+ 'SummaryMovingAverage', 'ProgressBar']
class StepStatPrinter(Callback):
......@@ -41,6 +45,24 @@ class StepStatPrinter(Callback):
logger.info("{}: {}".format(n, v))
+ class MaintainStepCounter(Callback):
+ """
+ Maintains the global step in the graph and also creates the local step tensor.
+ This callback is always enabled by the trainer; you do not need to add it yourself.
+ """
+ def _setup_graph(self):
+ # ensure the global_step variable exists
+ get_global_step_var()
+ self.gs_incr_var = self.trainer.sess.graph.get_tensor_by_name(GLOBAL_STEP_INCR_VAR_NAME)
+ self.local_step = tf.mod(
+ self.gs_incr_var, self.trainer.config.step_per_epoch,
+ name=LOCAL_STEP_OP_NAME)
+ def _extra_fetches(self):
+ return [self.gs_incr_var.op]
class SummaryMovingAverage(Callback):
""" Maintain the moving average of the tensors
in every step, and summarize them. Enabled by default.
......
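The new callback derives the local step from the post-increment value of the counter with a modulo, so it counts steps within the current epoch and wraps to zero at every epoch boundary. A self-contained sketch of the same arithmetic; ``step_per_epoch`` and the variable names here are illustrative:

    import tensorflow as tf

    step_per_epoch = 100
    global_step = tf.get_variable('global_step', initializer=0,
                                  trainable=False, dtype=tf.int32)
    # value of the counter *after* this step's increment
    global_step_incr = tf.assign_add(global_step, 1, name='global_step_incr')
    # position within the current epoch; hits 0 every step_per_epoch steps
    local_step = tf.mod(global_step_incr, step_per_epoch, name='local_step')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print([sess.run(local_step) for _ in range(3)])   # [1, 2, 3]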
......@@ -3,12 +3,14 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
- from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME
import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
+ from ..utils.naming import GLOBAL_STEP_VAR_NAME, GLOBAL_STEP_OP_NAME, GLOBAL_STEP_INCR_OP_NAME
+ from ..utils.argtools import memoized
__all__ = ['get_default_sess_config',
'get_global_step_value',
'get_global_step_var',
......@@ -43,6 +45,7 @@ def get_default_sess_config(mem_fraction=0.99):
return conf
+ @memoized
def get_global_step_var():
"""
Returns:
......@@ -54,11 +57,14 @@ def get_global_step_var():
except KeyError:
scope = tf.get_variable_scope()
assert scope.name == '', \
"Creating global_step_var under a variable scope would cause problems!"
with tf.variable_scope(scope, reuse=False):
"The global_step variable should be created under the root variable scope!"
with tf.variable_scope(scope, reuse=False), \
tf.name_scope(None):
var = tf.get_variable(GLOBAL_STEP_OP_NAME,
initializer=0,
trainable=False, dtype=tf.int32)
+ # also create the incr operation
+ tf.assign_add(var, 1, name=GLOBAL_STEP_INCR_OP_NAME)
return var
......
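The added ``tf.name_scope(None)`` is what makes lookup-by-name reliable: it clears any enclosing name scope, so ``global_step`` and ``global_step_incr`` receive exactly the graph names that the ``*_VAR_NAME`` constants refer to. A hedged illustration of the failure mode this avoids; ``tower0`` is a hypothetical enclosing scope:

    import tensorflow as tf

    v = tf.get_variable('global_step', initializer=0,
                        trainable=False, dtype=tf.int32)
    with tf.name_scope('tower0'):              # hypothetical enclosing scope
        bad = tf.assign_add(v, 1, name='incr')
        with tf.name_scope(None):              # reset to the root name scope
            good = tf.assign_add(v, 1, name='incr2')
    print(bad.name)    # 'tower0/incr:0' -- not findable as 'incr:0'
    print(good.name)   # 'incr2:0'       -- findable by its bare name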
......@@ -13,7 +13,7 @@ from .config import TrainConfig
from ..utils import logger
from ..utils.timer import timed_operation
from ..callbacks import StatHolder
- from ..tfutils import get_global_step_var, get_global_step_value
+ from ..tfutils import get_global_step_value
from ..tfutils.modelutils import describe_model
from ..tfutils.summary import create_scalar_summary
......@@ -144,7 +144,6 @@ class Trainer(object):
"""
self._setup()
describe_model()
- get_global_step_var() # ensure such var exists
# some final operations that might modify the graph
logger.info("Setup callbacks ...")
self.config.callbacks.setup_graph(weakref.proxy(self))
......
......@@ -6,7 +6,7 @@ import tensorflow as tf
from ..callbacks import (
Callbacks, SummaryMovingAverage,
- StatPrinter, ProgressBar)
+ StatPrinter, ProgressBar, MaintainStepCounter)
from ..dataflow.base import DataFlow
from ..models import ModelDesc
from ..utils import logger
......@@ -42,6 +42,8 @@ class TrainConfig(object):
is only used to provide the defaults. The defaults are
``[SummaryMovingAverage(), ProgressBar(), StatPrinter()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
+ Note that ``StatPrinter`` should come last so that it can print the
+ stats generated by the other callbacks.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
......@@ -83,7 +85,7 @@ class TrainConfig(object):
assert_type(callbacks, list)
if extra_callbacks is None:
extra_callbacks = [SummaryMovingAverage(), ProgressBar(), StatPrinter()]
- self.callbacks = callbacks + extra_callbacks
+ self.callbacks = [MaintainStepCounter()] + callbacks + extra_callbacks
assert_type(self.callbacks, list)
self.callbacks = Callbacks(self.callbacks)
......
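With this change a user never lists ``MaintainStepCounter`` explicitly; ``TrainConfig`` prepends it automatically. A hedged usage sketch; the model, dataflow and optimizer shown are placeholders for illustration, not part of this commit:

    import tensorflow as tf
    from tensorpack import TrainConfig
    from tensorpack.callbacks import ModelSaver

    config = TrainConfig(
        dataflow=my_dataflow,                    # placeholder DataFlow
        model=MyModel(),                         # placeholder ModelDesc
        optimizer=tf.train.AdamOptimizer(1e-3),
        callbacks=[ModelSaver()],
        # extra_callbacks defaults to
        # [SummaryMovingAverage(), ProgressBar(), StatPrinter()], so the final
        # list is [MaintainStepCounter()] + callbacks + extra_callbacks
    )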
......@@ -6,7 +6,6 @@
import tensorflow as tf
from ..utils import logger
- from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors
from .input_data import QueueInput, FeedfreeInput
......@@ -101,8 +100,7 @@ class SimpleFeedfreeTrainer(
cost, grads = self._get_cost_and_grad()
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
- self.train_op = self.config.optimizer.apply_gradients(
- grads, get_global_step_var(), name='min_op')
+ self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
# skip training
# self.train_op = tf.group(*self.dequed_inputs)
......
......@@ -11,8 +11,7 @@ from six.moves import zip, range
from ..utils import logger
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..utils.concurrency import LoopThread
- from ..tfutils import (backup_collection, restore_collection,
- get_global_step_var, TowerContext)
+ from ..tfutils import (backup_collection, restore_collection, TowerContext)
from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from .base import Trainer
......@@ -112,8 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = apply_grad_processors(grads, self.model.get_gradient_processor())
- self.train_op = self.config.optimizer.apply_gradients(
- grads, get_global_step_var(), name='min_op')
+ self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
class AsyncMultiGPUTrainer(MultiGPUTrainer,
......@@ -163,8 +161,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
# use grad from the first tower for iteration in main thread
- self.train_op = self.config.optimizer.apply_gradients(
- grad_list[0], get_global_step_var(), name='min_op')
+ self.train_op = self.config.optimizer.apply_gradients(grad_list[0], name='min_op')
self._start_async_threads(grad_list)
......
......@@ -7,8 +7,9 @@ import tensorflow as tf
from .base import Trainer
from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
- from ..tfutils import (get_tensors_by_names, freeze_collection,
- get_global_step_var, TowerContext)
+ from ..tfutils import (get_tensors_by_names,
+ freeze_collection,
+ TowerContext)
from ..predict import OnlinePredictor, build_prediction_graph
from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput
......@@ -88,8 +89,7 @@ class SimpleTrainer(Trainer):
grads = apply_grad_processors(grads,
self.model.get_gradient_processor())
- self.train_op = self.config.optimizer.apply_gradients(
- grads, get_global_step_var(), name='min_op')
+ self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
def _trigger_epoch(self):
if self.summary_op is not None:
......
......@@ -7,6 +7,12 @@ import tensorflow as tf
GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0'
+ GLOBAL_STEP_INCR_OP_NAME = 'global_step_incr'
+ GLOBAL_STEP_INCR_VAR_NAME = 'global_step_incr:0'
+ LOCAL_STEP_OP_NAME = 'local_step'
+ LOCAL_STEP_VAR_NAME = 'local_step:0'
# prefix of predict tower
PREDICT_TOWER = 'towerp'
......
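The paired ``*_OP_NAME`` / ``*_VAR_NAME`` constants follow TensorFlow's naming rule: an op named ``name`` exposes its first output as the tensor ``name:0``, which is the form ``graph.get_tensor_by_name`` expects. A quick illustration:

    import tensorflow as tf

    var = tf.get_variable('global_step', initializer=0,
                          trainable=False, dtype=tf.int32)
    incr = tf.assign_add(var, 1, name='global_step_incr')
    print(incr.op.name)   # 'global_step_incr'    (the op name)
    print(incr.name)      # 'global_step_incr:0'  (its first output tensor)
    g = tf.get_default_graph()
    assert g.get_tensor_by_name('global_step_incr:0') is incr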