Commit f73717ab authored by Yuxin Wu

fix UPDATE_OPS collection (#81)

parent 89b0c256
@@ -6,12 +6,11 @@ The backward compatibility will be __preserved for several months__, with a depre
 so you won't need to look here very often.
 Here is a list of things that were changed, starting from an early version.
-TensorFlow itself also changes API and those are not listed here.
+TensorFlow itself also changed APIs before 1.0 and those are not listed here.
 + [2017/05/06](https://github.com/ppwwyyxx/tensorpack/commit/0774ec66e66075486f6a36aba63cc2a151b9fec8).
   `replace_get_variable` was deprecated in favor of the official `custom_getter` interface.
   `{freeze,remap}_get_variable` was renamed to `{freeze,remap}_variables`.
 + [2017/04/09](https://github.com/ppwwyyxx/tensorpack/commit/5beab907895aec36bdcaed62e25b976aad7979b8).
   `ParamRestore` was renamed to `DictRestore`.
 + [2017/03/16](https://github.com/ppwwyyxx/tensorpack/commit/ccae46f4a3ca89dc3df901a338eef8447d19a730).
......
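An aside on the 2017/05/06 entry above: the official `custom_getter` interface that replaced `replace_get_variable` hooks variable creation per scope. Below is a minimal sketch under TF 1.x; the `freeze_getter` function is made up to mimic the spirit of `freeze_variables` and is not from this commit:

    import tensorflow as tf

    def freeze_getter(getter, *args, **kwargs):
        # Intercept tf.get_variable() and hand back a gradient-stopped
        # view, so variables created in this scope receive no gradients.
        return tf.stop_gradient(getter(*args, **kwargs))

    with tf.variable_scope('frozen', custom_getter=freeze_getter):
        w = tf.get_variable('w', shape=[3])  # 'w' behaves as frozen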
@@ -5,21 +5,25 @@
 """ Graph related callbacks"""

 import tensorflow as tf
 from ..utils import logger
 from .base import Callback

-__all__ = ['RunOp']
+__all__ = ['RunOp', 'RunUpdateOps']


 class RunOp(Callback):
     """ Run an Op. """

     def __init__(self, setup_func,
-                 run_before=True, run_as_trigger=True):
+                 run_before=True, run_as_trigger=True, run_step=False):
         """
         Args:
             setup_func: a function that returns the Op in the graph
             run_before (bool): run the Op before training
-            run_epoch (bool): run the Op on every epoch trigger
+            run_as_trigger (bool): run the Op on every trigger
+            run_step (bool): run the Op every step (along with training)

         Examples:
             The `DQN Example
@@ -29,10 +33,15 @@ class RunOp(Callback):
         self.setup_func = setup_func
         self.run_before = run_before
         self.run_as_trigger = run_as_trigger
+        self.run_step = run_step

     def _setup_graph(self):
         self._op = self.setup_func()

+    def _before_run(self, _):
+        if self.run_step:
+            return [self._op]
+
     def _before_train(self):
         if self.run_before:
             self._op.run()

@@ -40,3 +49,20 @@ class RunOp(Callback):
     def _trigger(self):
         if self.run_as_trigger:
             self._op.run()
+
+
+class RunUpdateOps(RunOp):
+    """
+    Run ops from the collection UPDATE_OPS every step.
+    """
+    def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
+        def f():
+            ops = tf.get_collection(collection)
+            if ops:
+                logger.info("Applying UPDATE_OPS collection of {} ops.".format(len(ops)))
+                return tf.group(*ops, name='update_ops')
+            else:
+                return tf.no_op(name='empty_update_ops')
+        super(RunUpdateOps, self).__init__(
+            f, run_before=False, run_as_trigger=False, run_step=True)
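To see what the new callback amounts to at runtime, here is a self-contained plain-TF sketch of the same behavior; the toy graph and learning rate are invented, and only the handling of `tf.GraphKeys.UPDATE_OPS` mirrors the code above:

    import tensorflow as tf

    # A batch-norm layer registers its moving-average updates in UPDATE_OPS.
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.layers.batch_normalization(x, training=True)
    loss = tf.reduce_mean(tf.square(y))
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

    # What RunUpdateOps builds once: the grouped UPDATE_OPS collection...
    update_op = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS),
                         name='update_ops')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ...and what _before_run triggers: run it along with every step.
        sess.run([train_op, update_op], feed_dict={x: [[1., 2., 3., 4.]]})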
@@ -128,9 +128,9 @@ class ModelDesc(object):
         It calls :meth:`ModelDesc._get_cost()` which by default returns
         ``self.cost``. You can override :meth:`_get_cost()` if needed.

-        This function also applies tfslim collections to the cost automatically,
-        including ``tf.GraphKeys.REGULARIZATION_LOSSES`` and ``tf.GraphKeys.UPDATE_OPS``.
-        This is because slim users would expect the regularizer being automatically applied once used in slim layers.
+        This function also applies the collection
+        ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically,
+        because slim users expect the regularizer to be applied once slim layers are used.
         """
         cost = self._get_cost()
         return apply_slim_collections(cost)
......
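A small plain-TF sketch of the behavior that `get_cost()` keeps: the variable, cost, and regularizer scale below are invented for illustration, and the add-up mirrors `apply_slim_collections` further down this page:

    import tensorflow as tf

    # Creating a variable with a regularizer populates the
    # REGULARIZATION_LOSSES collection automatically.
    w = tf.get_variable('w', shape=[3, 3],
                        regularizer=tf.contrib.layers.l2_regularizer(1e-4))
    cost = tf.reduce_sum(tf.square(w))

    # Fold the collected regularizers into the total cost, as get_cost() does.
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses:
        reg_loss = tf.add_n(reg_losses, name='regularize_loss')
        cost = tf.add(reg_loss, cost, name='total_cost')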
@@ -10,7 +10,7 @@ import six

 from ..utils import logger
 from ..utils.develop import deprecated
 from ..utils.argtools import memoized
-from ..utils.naming import SUMMARY_BACKUP_KEYS
+from ..utils.naming import TOWER_FREEZE_KEYS
 from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
 from ..tfutils.collection import freeze_collection

@@ -188,7 +188,7 @@ class PredictorTowerBuilder(object):
         # No matter where this gets called, clear any existing name scope.
         device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
         with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                freeze_collection(TOWER_FREEZE_KEYS), \
                 TowerContext(towername, device=device, is_training=False):
             self._fn(tower)
......
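The switch above only changes which keys get frozen; for readers unfamiliar with `freeze_collection` itself, here is a hypothetical re-implementation. The real one lives in `tfutils.collection`; only the snapshot-and-restore semantics implied by its name (and by the `backup_collection`/`restore_collection` pair used elsewhere in this commit) are assumed:

    import tensorflow as tf
    from contextlib import contextmanager

    @contextmanager
    def freeze_collection(keys):
        # Snapshot the named collections on entry...
        backup = {k: list(tf.get_collection(k)) for k in keys}
        yield
        # ...and restore them on exit, discarding anything the predictor
        # tower appended (summaries, UPDATE_OPS) in between.
        for k, items in backup.items():
            ref = tf.get_collection_ref(k)
            del ref[:]
            ref.extend(items)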
@@ -8,7 +8,6 @@ from tabulate import tabulate

 from ..utils import logger
 from .summary import add_moving_summary
-from .tower import get_current_tower_context

 __all__ = ['describe_model', 'get_shape_str', 'apply_slim_collections']

@@ -53,10 +52,7 @@ def get_shape_str(tensors):

 def apply_slim_collections(cost):
     """
-    Apply slim collections to the cost, including:
-    1. adding the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
-    2. make the cost depend on ``tf.GraphKeys.UPDATE_OPS``.
+    Add the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost.

     Args:
         cost: a scalar tensor

@@ -70,15 +66,4 @@ def apply_slim_collections(cost):
         reg_loss = tf.add_n(list(regulization_losses), name="regularize_loss")
         cost = tf.add(reg_loss, cost, name='total_cost')
         add_moving_summary(reg_loss, cost)

-    # As these batch-norm statistics quickly accumulate, there is no significant loss of accuracy
-    # if only the main tower handles all batch-normalization updates, which are then shared across
-    # the towers
-    ctx = get_current_tower_context()
-    if ctx is not None and ctx.is_main_training_tower:
-        non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
-        if non_grad_updates:
-            logger.info("Applying UPDATE_OPS collection from the first tower on cost.")
-            with tf.control_dependencies(non_grad_updates):
-                cost = tf.identity(cost, name='cost_with_update')
     return cost
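For reference, the pattern deleted above in runnable, self-contained form; a toy `tf.assign_add` stands in for a real batch-norm update, and this shows the old behavior that this commit removes in favor of the `RunUpdateOps` callback:

    import tensorflow as tf

    v = tf.Variable(0, name='bn_stat')
    # Stand-in for a batch-norm moving-average update.
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.assign_add(v, 1))

    cost = tf.constant(1.0)
    # The removed trick: fetching the cost also runs UPDATE_OPS.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        cost = tf.identity(cost, name='cost_with_update')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(cost)
        print(sess.run(v))  # prints 1: the update ran as a side effect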
@@ -9,7 +9,7 @@ import re

 from six.moves import zip, range

 from ..utils import logger
-from ..utils.naming import SUMMARY_BACKUP_KEYS
+from ..utils.naming import TOWER_FREEZE_KEYS
 from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection

@@ -50,8 +50,8 @@ class MultiGPUTrainer(Trainer):
                 ret.append(func())
                 if idx == 0:
-                    # avoid repeated summary from each device
-                    backup = backup_collection(SUMMARY_BACKUP_KEYS)
+                    # avoid repeated summary & update_ops from each device
+                    backup = backup_collection(TOWER_FREEZE_KEYS)
         restore_collection(backup)
         return ret
......
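Spelled out as a standalone sketch, the tower idiom above keeps only tower 0's collection entries. The helper bodies, the key list, and the toy `build_tower` are assumptions made for illustration; the real helpers live in `tfutils.collection`:

    import tensorflow as tf

    TOWER_FREEZE_KEYS = [tf.GraphKeys.SUMMARIES, tf.GraphKeys.UPDATE_OPS]

    def backup_collection(keys):
        # Snapshot each collection into a plain dict.
        return {k: list(tf.get_collection(k)) for k in keys}

    def restore_collection(backup):
        # Overwrite each collection with its snapshot.
        for k, items in backup.items():
            ref = tf.get_collection_ref(k)
            del ref[:]
            ref.extend(items)

    def build_tower(idx):
        # Toy stand-in: each tower registers one summary.
        tf.summary.scalar('loss_tower{}'.format(idx), tf.constant(0.0))

    for idx in range(4):
        build_tower(idx)
        if idx == 0:   # snapshot right after the first tower is built
            backup = backup_collection(TOWER_FREEZE_KEYS)
    restore_collection(backup)   # later towers' summaries are dropped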
@@ -24,6 +24,8 @@ INPUTS_KEY = 'INPUTS_METAINFO'

 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
+TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]

 # export all upper case variables
 all_local_names = locals().keys()
 __all__ = [x for x in all_local_names if x.isupper()]