Commit 67de41d0 authored by Yuxin Wu

Use CollectionGuard to manage collection changes in towers

parent 118aae66
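The change replaces the ad-hoc `freeze_collection` / `filter_vars_by_vs_name` calls with a single CollectionGuard owned by each TowerContext: it snapshots all collections on entry, restores the frozen keys and logs any other changes on exit, and answers "what did this tower add?" by set difference. A minimal sketch of that idea (not the library code itself):

import tensorflow as tf

# snapshot every collection before building a tower
before = {k: set(tf.get_collection(k))
          for k in tf.get_default_graph().get_all_collection_keys()}
# ... build one tower here ...
# items added by this tower only, recovered by set difference
added_by_tower = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) \
    - before.get(tf.GraphKeys.UPDATE_OPS, set())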
@@ -11,7 +11,6 @@ import numpy as np
 from ..utils import logger
 from .base import Callback
-from ..tfutils.common import get_tensors_by_names
 from six.moves import zip

 __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor']

@@ -120,7 +119,7 @@ class ProcessTensors(Callback):
         self._fn = fn

     def _setup_graph(self):
-        tensors = get_tensors_by_names(self._names)
+        tensors = self.get_tensors_maybe_in_tower(self._names)
         self._fetch = tf.train.SessionRunArgs(fetches=tensors)

     def _before_run(self, _):
...
@@ -172,7 +172,10 @@ class ModelDesc(ModelDescBase):
         ctx = get_current_tower_context()
         cost = self._build_graph_get_cost(*inputs)
-        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+        if ctx.has_own_variables:
+            varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+        else:
+            varlist = tf.trainable_variables()
         opt = self.get_optimizer()
         grads = opt.compute_gradients(
             cost, var_list=varlist,
...
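The same `has_own_variables` branch recurs in `regularize_cost` and in `SingleCostTrainer` further down. The motivation, as a minimal standalone sketch (scope names are made up for illustration): under replicated training every tower creates its own copy of each variable, so `tf.trainable_variables()` returns all copies and each tower must pick out only its own.

import tensorflow as tf

with tf.variable_scope('tower0'):
    tf.get_variable('w', [2])
with tf.variable_scope('tower1'):
    tf.get_variable('w', [2])

# Both copies are trainable; computing gradients against another
# tower's copy would be wrong, hence the per-tower selection above.
print([v.name for v in tf.trainable_variables()])
# ['tower0/w:0', 'tower1/w:0']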
@@ -7,8 +7,6 @@ from contextlib import contextmanager
 from ..utils import logger
 from ..tfutils.tower import TowerContext, TowerFuncWrapper
-from ..tfutils.collection import freeze_collection
-from ..utils.naming import TOWER_FREEZE_KEYS
 from ..input_source import PlaceholderInput
 from .training import GraphBuilder

@@ -56,11 +54,7 @@ class SimplePredictBuilder(GraphBuilder):
         with tf.device(self._device), \
                 self._maybe_open_vs(), \
                 TowerContext(
-                    self._ns_name, is_training=False, vs_name=self._vs_name), \
-                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
-            # also freeze UPDATE_OPS in inference, because they should never be used
-            # TODO a better way to log and warn about collection change during build_graph.
+                    self._ns_name, is_training=False, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
             assert isinstance(inputs, (list, tuple)), inputs
             return tower_fn(*inputs)

@@ -92,10 +86,7 @@ class PredictorFactory(object):
             "Prediction tower with name '{}' already exists!".format(tower_name)
         with tf.device(device), \
-                TowerContext(tower_name, is_training=False), \
-                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
-            # also freeze UPDATE_OPS in inference, because they should never be used
-            # TODO a better way to log and warn about collection change during build_graph.
+                TowerContext(tower_name, is_training=False):
             inputs_desc = self._model.get_inputs_desc()
             if input is None:
                 input = PlaceholderInput()
...
@@ -10,9 +10,7 @@ from six.moves import zip, range
 from ..utils import logger
 from ..tfutils.tower import TowerContext
 from ..tfutils.common import get_tf_version_number
-from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
-from ..utils.naming import TOWER_FREEZE_KEYS
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,

@@ -95,11 +93,6 @@ class DataParallelBuilder(GraphBuilder):
                 # so these duplicated variables won't be saved by default.
                 with override_to_local_variable(enable=usevs):
                     ret.append(func())
-
-                if idx == 0:
-                    # avoid duplicated summary & update_ops from each device
-                    backup = backup_collection(TOWER_FREEZE_KEYS)
-        restore_collection(backup)
         return ret
...
@@ -27,7 +27,10 @@ os.environ['TF_AVGPOOL_USE_CUDNN'] = '1'  # issue#8566
 try:
     import tensorflow as tf  # noqa
-    assert int(tf.__version__.split('.')[0]) >= 1, "TF>=1.0 is required!"
+    _version = tf.__version__.split('.')
+    assert int(_version[0]) >= 1, "TF>=1.0 is required!"
+    if int(_version[1]) < 2:
+        print("TF<1.2 support will be removed in the future!")
     _HAS_TF = True
 except ImportError:
     _HAS_TF = False
...
@@ -48,11 +48,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
         # because the vs_name used in inference can be '', therefore the
         # variable filter will fail
         return tf.constant(0, dtype=tf.float32, name='empty_' + name)
-    params = tf.trainable_variables()

-    # If vars are shared, use all of them
+    # If vars are shared, regularize all of them
     # If vars are replicated, only regularize those in the current tower
-    params = ctx.filter_vars_by_vs_name(params)
+    if ctx.has_own_variables:
+        params = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+    else:
+        params = tf.trainable_variables()

     G = tf.get_default_graph()

@@ -93,21 +95,22 @@ def regularize_cost_from_collection(name='regularize_cost'):
     Returns:
         a scalar tensor, the regularization loss, or None
     """
-    regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
     ctx = get_current_tower_context()
     if not ctx.is_training:
-        # Currently cannot build the wd_cost correctly at inference,
+        # TODO Currently cannot build the wd_cost correctly at inference,
         # because the vs_name used in inference can be '', therefore the
         # variable filter will fail
         return None
-    if len(regularization_losses) > 0:
-        # NOTE: this collection doesn't grow with towers.
-        # It is only added with variables that are newly created.
-        if ctx.has_own_variables:   # be careful of the first tower (name='')
-            regularization_losses = ctx.filter_vars_by_vs_name(regularization_losses)
-        logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
-        reg_loss = tf.add_n(list(regularization_losses), name=name)
+    # NOTE: this collection doesn't always grow with towers.
+    # It is only added with variables that are newly created.
+    if ctx.has_own_variables:   # be careful of the first tower (name='')
+        losses = ctx.get_collection_in_tower(tf.GraphKeys.REGULARIZATION_LOSSES)
+    else:
+        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    if len(losses) > 0:
+        logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(losses)))
+        reg_loss = tf.add_n(losses)
         return reg_loss
     else:
         return None
...
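The NOTE above is the subtle point behind this hunk: `tf.get_variable` adds to REGULARIZATION_LOSSES only when a variable is actually created, so a tower that reuses variables contributes nothing to the collection. A minimal TF1 sketch of that behavior (the scope name and regularizer strength are illustrative):

import tensorflow as tf

reg = tf.contrib.layers.l2_regularizer(1e-4)
with tf.variable_scope('model'):
    tf.get_variable('w', [2], regularizer=reg)   # created: one loss added
with tf.variable_scope('model', reuse=True):
    tf.get_variable('w', [2], regularizer=reg)   # reused: nothing added

print(len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)))  # 1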
@@ -8,21 +8,27 @@ from copy import copy
 import six
 from contextlib import contextmanager

+from ..utils import logger
+from ..utils.argtools import memoized
+
 __all__ = ['backup_collection',
            'restore_collection',
            'freeze_collection']


-def backup_collection(keys):
+def backup_collection(keys=None):
     """
     Args:
-        keys (list): list of collection keys to backup
+        keys (list): list of collection keys to backup.
+            Defaults to all keys in the graph.

     Returns:
         dict: the backup
     """
+    if keys is None:
+        keys = tf.get_default_graph().get_all_collection_keys()
     ret = {}
-    assert isinstance(keys, (list, tuple))
+    assert isinstance(keys, (list, tuple, set))
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret
@@ -52,3 +58,103 @@ def freeze_collection(keys):
     backup = backup_collection(keys)
     yield
     restore_collection(backup)
+
+
+@memoized
+def get_inverse_graphkeys():
+    ret = {}
+    for name in dir(tf.GraphKeys):
+        if name.startswith('_'):
+            continue
+        if name in ['VARIABLES']:   # will produce deprecated warning
+            continue
+        ret[getattr(tf.GraphKeys, name)] = "tf.GraphKeys.{}".format(name)
+    return ret
+
+
+class CollectionGuard(object):
+    """
+    A context to maintain collection change in a tower.
+    """
+
+    original = None
+
+    def __init__(self, name, check_diff,
+                 freeze_keys=[],
+                 diff_whitelist=[
+                     tf.GraphKeys.TRAINABLE_VARIABLES,
+                     tf.GraphKeys.GLOBAL_VARIABLES,
+                     tf.GraphKeys.LOCAL_VARIABLES]):
+        """
+        Args:
+            name (str): name of the tower
+            check_diff (bool): whether to test and print about collection change
+            freeze_keys (list): list of keys to freeze
+            diff_whitelist (list): list of keys to not print, when check_diff is True
+        """
+        self._name = name
+        self._check_diff = check_diff
+        self._whitelist = set(diff_whitelist)
+        self._freeze_keys = freeze_keys
+        self._inverse_graphkeys = get_inverse_graphkeys()
+
+    def _key_name(self, name):
+        return self._inverse_graphkeys.get(name, name)
+
+    def __enter__(self):
+        self.original = backup_collection()
+        self._freeze_backup = backup_collection(self._freeze_keys)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            return False
+        new_coll = backup_collection()
+        if self._check_diff:
+            self._print_diff(new_coll)
+        self._restore_freeze(new_coll)
+        return False
+
+    def _print_diff(self, new):
+        newly_created = []
+        size_change = []
+        for k, v in six.iteritems(new):
+            if k in self._whitelist or k in self._freeze_keys:
+                continue
+            if k not in self.original:
+                newly_created.append(self._key_name(k))
+            else:
+                old_v = self.original[k]
+                if len(old_v) != len(v):
+                    size_change.append((self._key_name(k), len(old_v), len(v)))
+        if newly_created:
+            logger.info(
+                "New collections created in {}: {}".format(
+                    self._name, ', '.join(newly_created)))
+        if size_change:
+            logger.info(
+                "Size of these collections were changed in {}: {}".format(
+                    self._name, ', '.join(
+                        map(lambda t: "({}: {}->{})".format(*t),
+                            size_change))))

+    def _restore_freeze(self, new):
+        size_change = []
+        for k, v in six.iteritems(self._freeze_backup):
+            newv = new.get(k, [])
+            if len(v) != len(newv):
+                size_change.append((self._key_name(k), len(v), len(newv)))
+        if size_change:
+            logger.info(
+                "These collections were modified but restored in {}: {}".format(
+                    self._name, ', '.join(
+                        map(lambda t: "({}: {}->{})".format(*t),
+                            size_change))))
+
+    def get_collection_in_tower(self, key):
+        """
+        Get items from this collection that are added in the current tower.
+        """
+        new = set(tf.get_collection(key))
+        old = set(self.original.get(key, []))
+        return list(new - old)
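Roughly how the new guard is used on its own; the tower name, variable, and summary below are made up for illustration. Note that `__enter__` returns None, so keep a reference to the guard instead of using `with ... as`:

guard = CollectionGuard('tower1', check_diff=True,
                        freeze_keys=[tf.GraphKeys.SUMMARIES])
with guard:
    w = tf.get_variable('w1', [3])
    tf.summary.scalar('w1_sum', tf.reduce_sum(w))  # frozen key: restored on exit

# Items this block added to a collection, recovered by set difference:
print(guard.get_collection_in_tower(tf.GraphKeys.GLOBAL_VARIABLES))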
@@ -8,6 +8,8 @@ from six.moves import zip
 from ..utils import logger
 from ..utils.argtools import call_only_once
+from ..utils.naming import TRAIN_TOWER_FREEZE_KEYS, PREDICT_TOWER_FREEZE_KEYS
+from .collection import CollectionGuard
 from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name

 __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',

@@ -44,6 +46,11 @@ class TowerContext(object):
             assert not self._initial_vs_reuse, \
                 "Cannot create tower {} with reuse=True!".format(tower_name)

+        self._collection_guard = CollectionGuard(
+            self._name,
+            check_diff=not self.is_main_training_tower,
+            freeze_keys=self._keys_to_freeze())
+
     @property
     def is_main_training_tower(self):
         return self.is_training and self._index == 0

@@ -91,6 +98,12 @@ class TowerContext(object):
         prefix = self._vs_name + '/'
         return [v for v in varlist if v.op.name.startswith(prefix)]

+    def get_collection_in_tower(self, key):
+        """
+        Get items from this collection that are added in the current tower.
+        """
+        return self._collection_guard.get_collection_in_tower(key)
+
     @property
     def index(self):
         return self._index

@@ -117,6 +130,13 @@ class TowerContext(object):
             ret.append(tf.name_scope(self._name + '/'))
         return ret

+    def _keys_to_freeze(self):
+        if self.is_main_training_tower:
+            return []
+        if self.is_training:
+            return TRAIN_TOWER_FREEZE_KEYS
+        return PREDICT_TOWER_FREEZE_KEYS
+
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"

@@ -125,6 +145,7 @@ class TowerContext(object):
         assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"

         self._ctxs = self._get_scopes()
+        self._ctxs.append(self._collection_guard)
         for c in self._ctxs:
             c.__enter__()

@@ -139,6 +160,12 @@ class TowerContext(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CurrentTowerContext
         _CurrentTowerContext = None
+        if not self.has_own_variables:
+            diff_trainable_vars = self._collection_guard.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            assert len(diff_trainable_vars) == 0, \
+                "New TRAINABLE_VARIABLES shouldn't be created in {}: ".format(
+                    self._name) + ', '.join([k.name for k in diff_trainable_vars])
         for c in self._ctxs[::-1]:
             c.__exit__(exc_type, exc_val, exc_tb)
         return False
...
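With the guard entered as the last of the context's scopes, per-tower collection queries become available to any code running inside it. A consumer-side sketch (`build_model` and `images` are hypothetical):

with TowerContext('tower1', is_training=True):
    logits = build_model(images)
    ctx = get_current_tower_context()
    # ops added by this tower only, via the guard's set difference:
    tower_update_ops = ctx.get_collection_in_tower(tf.GraphKeys.UPDATE_OPS)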
@@ -156,7 +156,10 @@ class SingleCostTrainer(TowerTrainer):
             ctx = get_current_tower_context()
             cost = get_cost_fn(*input.get_input_tensors())

-            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+            if ctx.has_own_variables:
+                varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            else:
+                varlist = tf.trainable_variables()
             opt = get_opt_fn()
             grads = opt.compute_gradients(
                 cost, var_list=varlist,
...
@@ -12,4 +12,7 @@ MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

-TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS
+TRAIN_TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS
+PREDICT_TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
+# also freeze UPDATE_OPS in inference, because they should never be used
+# TODO a better way to log and warn about collection change during build_graph.
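For reference, with TF's standard key strings these lists resolve to (a quick check, assuming the constants defined above):

# tf.GraphKeys.SUMMARIES == 'summaries', tf.GraphKeys.UPDATE_OPS == 'update_ops'
# TRAIN_TOWER_FREEZE_KEYS   -> ['summaries', 'MOVING_SUMMARY_OPS']
# PREDICT_TOWER_FREEZE_KEYS -> ['summaries', 'MOVING_SUMMARY_OPS', 'update_ops']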