Commit f73717ab authored by Yuxin Wu

fix UPDATE_OPS collection (#81)

parent 89b0c256
@@ -6,12 +6,11 @@ The backward compatibilty will be __preserved for several months__, with a depre
 so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
-TensorFlow itself also changes API and those are not listed here.
+TensorFlow itself also changed APIs before 1.0 and those are not listed here.
 + [2017/05/06](https://github.com/ppwwyyxx/tensorpack/commit/0774ec66e66075486f6a36aba63cc2a151b9fec8).
   `replace_get_variable` was deprecated in favor of the official `custom_getter` interface.
   `{freeze,remap}_get_variable` was renamed to `{freeze,remap}_variables`.
 + [2017/04/09](https://github.com/ppwwyyxx/tensorpack/commit/5beab907895aec36bdcaed62e25b976aad7979b8).
   `ParamRestore` was renamed to `DictRestore`.
 + [2017/03/16](https://github.com/ppwwyyxx/tensorpack/commit/ccae46f4a3ca89dc3df901a338eef8447d19a730).
......
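For reference, the `custom_getter` interface mentioned in the changelog entry above is TensorFlow's own hook for intercepting `tf.get_variable`. Below is a minimal sketch of the idea behind freezing variables through a custom getter; the getter and scope names are illustrative, not tensorpack's actual implementation.

```python
import tensorflow as tf

# Wrap every variable created in this scope with tf.stop_gradient,
# which is roughly what "freezing" a variable means.
def stop_gradient_getter(getter, *args, **kwargs):
    v = getter(*args, **kwargs)      # create or reuse the variable as usual
    return tf.stop_gradient(v)       # but block gradients flowing into it

with tf.variable_scope('frozen_net', custom_getter=stop_gradient_getter):
    w = tf.get_variable('w', shape=[3, 3])   # w receives no gradient updates
```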
@@ -5,21 +5,25 @@
 """ Graph related callbacks"""
+import tensorflow as tf
+from ..utils import logger
 from .base import Callback

-__all__ = ['RunOp']
+__all__ = ['RunOp', 'RunUpdateOps']


 class RunOp(Callback):
     """ Run an Op. """

     def __init__(self, setup_func,
-                 run_before=True, run_as_trigger=True):
+                 run_before=True, run_as_trigger=True, run_step=False):
         """
         Args:
             setup_func: a function that returns the Op in the graph
             run_before (bool): run the Op before training
-            run_epoch (bool): run the Op on every epoch trigger
+            run_as_trigger (bool): run the Op on every trigger
+            run_step (bool): run the Op every step (along with training)

         Examples:
             The `DQN Example
@@ -29,10 +33,15 @@ class RunOp(Callback):
         self.setup_func = setup_func
         self.run_before = run_before
         self.run_as_trigger = run_as_trigger
+        self.run_step = run_step

     def _setup_graph(self):
         self._op = self.setup_func()

+    def _before_run(self, _):
+        if self.run_step:
+            return [self._op]
+
     def _before_train(self):
         if self.run_before:
             self._op.run()
@@ -40,3 +49,20 @@ class RunOp(Callback):
     def _trigger(self):
         if self.run_as_trigger:
             self._op.run()
+
+
+class RunUpdateOps(RunOp):
+    """
+    Run ops from the collection UPDATE_OPS every step
+    """
+    def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
+        def f():
+            ops = tf.get_collection(collection)
+            if ops:
+                logger.info("Applying UPDATE_OPS collection of {} ops.".format(len(ops)))
+                return tf.group(*ops, name='update_ops')
+            else:
+                return tf.no_op(name='empty_update_ops')
+
+        super(RunUpdateOps, self).__init__(
+            f, run_before=False, run_as_trigger=False, run_step=True)
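A hedged usage sketch of the new callbacks (the module path is taken from this diff; the epoch counter and its names are purely illustrative):

```python
import tensorflow as tf
from tensorpack.callbacks.graph import RunOp, RunUpdateOps

# Run everything in the UPDATE_OPS collection (typically BatchNorm
# moving-average updates) together with every training step.
update_bn = RunUpdateOps(tf.GraphKeys.UPDATE_OPS)

# RunOp runs an arbitrary op; setup_func is only called in _setup_graph,
# i.e. after the training graph has been built.
counter = tf.Variable(0, trainable=False, name='epoch_counter')
bump = RunOp(lambda: tf.group(tf.assign_add(counter, 1), name='bump_counter'),
             run_before=False, run_as_trigger=True)   # run once per trigger/epoch

# both callbacks would then be passed to the trainer's callback list
```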
@@ -128,9 +128,9 @@ class ModelDesc(object):
         It calls :meth:`ModelDesc._get_cost()` which by default returns
         ``self.cost``. You can override :meth:`_get_cost()` if needed.

-        This function also applies tfslim collections to the cost automatically,
-        including ``tf.GraphKeys.REGULARIZATION_LOSSES`` and ``tf.GraphKeys.UPDATE_OPS``.
-        This is because slim users would expect the regularizer being automatically applied once used in slim layers.
+        This function also applies the collection
+        ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically,
+        because slim users would expect the regularizer to be automatically applied once used in slim layers.
         """
         cost = self._get_cost()
         return apply_slim_collections(cost)
......
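What `apply_slim_collections` now does is essentially add the accumulated regularization losses onto the cost. A standalone sketch of that pattern follows; the layer and regularizer choices are illustrative, not part of this commit.

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 32])
# A layer built with a kernel_regularizer puts its penalty into
# tf.GraphKeys.REGULARIZATION_LOSSES; slim layers behave the same way.
y = tf.layers.dense(x, 10,
                    kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))
data_cost = tf.reduce_mean(tf.square(y))      # stand-in for the real training loss

reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
reg_loss = tf.add_n(reg_losses, name='regularize_loss')
total_cost = tf.add(reg_loss, data_cost, name='total_cost')
```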
@@ -10,7 +10,7 @@ import six
 from ..utils import logger
 from ..utils.develop import deprecated
 from ..utils.argtools import memoized
-from ..utils.naming import SUMMARY_BACKUP_KEYS
+from ..utils.naming import TOWER_FREEZE_KEYS
 from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
 from ..tfutils.collection import freeze_collection
@@ -188,7 +188,7 @@ class PredictorTowerBuilder(object):
         # No matter where this get called, clear any existing name scope.
         device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
         with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                freeze_collection(TOWER_FREEZE_KEYS), \
                 TowerContext(towername, device=device, is_training=False):
             self._fn(tower)
......
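A hedged sketch of how the prediction tower uses `freeze_collection`, assuming it snapshots the listed collections on entry and restores the snapshot on exit; the summary written inside the block is only an illustration.

```python
import tensorflow as tf
from tensorpack.utils.naming import TOWER_FREEZE_KEYS
from tensorpack.tfutils.collection import freeze_collection

# Assumption: anything added to the frozen collections inside the block is
# dropped when the block exits, so a prediction tower built here cannot leak
# summaries or UPDATE_OPS into the training graph.
with freeze_collection(TOWER_FREEZE_KEYS):
    tf.summary.scalar('pred_tower_dummy', tf.constant(0.))  # discarded afterwards
```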
@@ -8,7 +8,6 @@ from tabulate import tabulate
 from ..utils import logger
 from .summary import add_moving_summary
-from .tower import get_current_tower_context

 __all__ = ['describe_model', 'get_shape_str', 'apply_slim_collections']
@@ -53,10 +52,7 @@ def get_shape_str(tensors):
 def apply_slim_collections(cost):
     """
-    Apply slim collections to the cost, including:
-    1. adding the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
-    2. make the cost depend on ``tf.GraphKeys.UPDATE_OPS``.
+    Add the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.

     Args:
         cost: a scalar tensor
@@ -70,15 +66,4 @@ def apply_slim_collections(cost):
         reg_loss = tf.add_n(list(regulization_losses), name="regularize_loss")
         cost = tf.add(reg_loss, cost, name='total_cost')
         add_moving_summary(reg_loss, cost)
-
-    # As these batch-norm statistics quickly accumulate, there is no significant loss of accuracy
-    # if only the main tower handles all batch-normalization updates, which are then shared across
-    # the towers
-    ctx = get_current_tower_context()
-    if ctx is not None and ctx.is_main_training_tower:
-        non_grad_updates = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
-        if non_grad_updates:
-            logger.info("Applying UPDATE_OPS collection from the first tower on cost.")
-            with tf.control_dependencies(non_grad_updates):
-                cost = tf.identity(cost, name='cost_with_update')
     return cost
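The removed block above is the classic TF 1.x idiom of attaching UPDATE_OPS to the cost via control dependencies; after this commit the same effect comes from the `RunUpdateOps` callback running the ops every step instead. For reference, a minimal sketch of the old idiom (the dummy cost is illustrative):

```python
import tensorflow as tf

cost = tf.constant(1.0, name='dummy_cost')      # stand-in for the real cost
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    # running `cost` now also runs every op in UPDATE_OPS
    with tf.control_dependencies(update_ops):
        cost = tf.identity(cost, name='cost_with_update')
```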
@@ -9,7 +9,7 @@ import re
 from six.moves import zip, range

 from ..utils import logger
-from ..utils.naming import SUMMARY_BACKUP_KEYS
+from ..utils.naming import TOWER_FREEZE_KEYS
 from ..utils.concurrency import LoopThread
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import backup_collection, restore_collection
@@ -50,8 +50,8 @@ class MultiGPUTrainer(Trainer):
                ret.append(func())

                if idx == 0:
-                   # avoid repeated summary from each device
-                   backup = backup_collection(SUMMARY_BACKUP_KEYS)
+                   # avoid repeated summary & update_ops from each device
+                   backup = backup_collection(TOWER_FREEZE_KEYS)
        restore_collection(backup)
        return ret
......
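A hedged sketch of the multi-tower pattern above, assuming `backup_collection` returns a snapshot that `restore_collection` later reinstates; the dummy tower-building function and the two-GPU loop are illustrative.

```python
import tensorflow as tf
from tensorpack.utils.naming import TOWER_FREEZE_KEYS
from tensorpack.tfutils.collection import backup_collection, restore_collection

def build_tower():
    # stand-in for building the model on one device; the real trainer calls
    # the model's graph-building function here
    tf.summary.scalar('tower_dummy', tf.constant(0.))

for idx in range(2):                        # two GPUs, for illustration
    with tf.device('/gpu:{}'.format(idx)):
        build_tower()
    if idx == 0:
        # snapshot summaries, moving summaries and UPDATE_OPS after tower 0
        backup = backup_collection(TOWER_FREEZE_KEYS)
# drop whatever the later towers appended to those collections
restore_collection(backup)
```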
@@ -24,6 +24,8 @@ INPUTS_KEY = 'INPUTS_METAINFO'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
+TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
+
 # export all upper case variables
 all_local_names = locals().keys()
 __all__ = [x for x in all_local_names if x.isupper()]