Commit 67de41d0 authored by Yuxin Wu

Use CollectionGuard to manage collection change in tower

parent 118aae66
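
In outline: every TowerContext now enters a CollectionGuard, which snapshots all graph collections on entry; on exit the guard restores the frozen keys, logs any other collection changes, and the snapshot lets callers ask for exactly what a tower added. A minimal usage sketch, assuming tensorpack at this commit under TF 1.x (the variable `w` stands in for hypothetical model code):

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext, get_current_tower_context

    with TowerContext('', is_training=True):    # enters a CollectionGuard internally
        ctx = get_current_tower_context()
        w = tf.get_variable('w', shape=[3])     # hypothetical model code
        # per-tower additions come from the guard's snapshot, not from
        # filtering tf.trainable_variables() by scope name:
        varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)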
@@ -11,7 +11,6 @@ import numpy as np
 from ..utils import logger
 from .base import Callback
-from ..tfutils.common import get_tensors_by_names
 from six.moves import zip

 __all__ = ['RunOp', 'RunUpdateOps', 'ProcessTensors', 'DumpTensors', 'DumpTensor']
@@ -120,7 +119,7 @@ class ProcessTensors(Callback):
         self._fn = fn

     def _setup_graph(self):
-        tensors = get_tensors_by_names(self._names)
+        tensors = self.get_tensors_maybe_in_tower(self._names)
         self._fetch = tf.train.SessionRunArgs(fetches=tensors)

     def _before_run(self, _):
@@ -172,7 +172,10 @@ class ModelDesc(ModelDescBase):
             ctx = get_current_tower_context()
             cost = self._build_graph_get_cost(*inputs)
-            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+            if ctx.has_own_variables:
+                varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            else:
+                varlist = tf.trainable_variables()
             opt = self.get_optimizer()
             grads = opt.compute_gradients(
                 cost, var_list=varlist,
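
The branch above replaces the old vs_name prefix filter. Per the comments kept in regularize.py below, that filter breaks when a tower's vs_name is the empty string: the prefix logic shown in tower.py (prefix = vs_name + '/') then matches on '/', which no variable name starts with. A hedged sketch of the failure mode:

    # old-style filtering, sketched; with vs_name == '' the prefix is '/'
    prefix = '' + '/'
    varlist = [v for v in tf.trainable_variables()
               if v.op.name.startswith(prefix)]    # empty for vs_name == ''

The snapshot diff used by get_collection_in_tower has no such blind spot, since it does not depend on names at all.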
@@ -7,8 +7,6 @@ from contextlib import contextmanager
 from ..utils import logger
 from ..tfutils.tower import TowerContext, TowerFuncWrapper
-from ..tfutils.collection import freeze_collection
-from ..utils.naming import TOWER_FREEZE_KEYS
 from ..input_source import PlaceholderInput
 from .training import GraphBuilder
@@ -56,11 +54,7 @@ class SimplePredictBuilder(GraphBuilder):
         with tf.device(self._device), \
                 self._maybe_open_vs(), \
                 TowerContext(
-                    self._ns_name, is_training=False, vs_name=self._vs_name), \
-                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
-            # also freeze UPDATE_OPS in inference, because they should never be used
-            # TODO a better way to log and warn about collection change during build_graph.
+                    self._ns_name, is_training=False, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
             assert isinstance(inputs, (list, tuple)), inputs
             return tower_fn(*inputs)
@@ -92,10 +86,7 @@ class PredictorFactory(object):
             "Prediction tower with name '{}' already exists!".format(tower_name)
         with tf.device(device), \
-                TowerContext(tower_name, is_training=False), \
-                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
-            # also freeze UPDATE_OPS in inference, because they should never be used
-            # TODO a better way to log and warn about collection change during build_graph.
+                TowerContext(tower_name, is_training=False):
             inputs_desc = self._model.get_inputs_desc()
             if input is None:
                 input = PlaceholderInput()
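
The freeze_collection wrappers deleted in these two hunks are not simply dropped: TowerContext now freezes the same keys itself through its CollectionGuard; see _keys_to_freeze() in the tower.py hunk below, which returns PREDICT_TOWER_FREEZE_KEYS (SUMMARY_BACKUP_KEYS plus UPDATE_OPS) for prediction towers. A hedged equivalence sketch:

    from tensorpack.tfutils.tower import TowerContext

    # before: TowerContext(...) combined with an explicit
    #   freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS])
    # after: the context alone; UPDATE_OPS/SUMMARIES added inside are
    # rolled back when the guard exits
    with TowerContext('pred_tower', is_training=False):
        pass  # build the prediction graph here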
@@ -10,9 +10,7 @@ from six.moves import zip, range
 from ..utils import logger
 from ..tfutils.tower import TowerContext
 from ..tfutils.common import get_tf_version_number
-from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
-from ..utils.naming import TOWER_FREEZE_KEYS
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
@@ -95,11 +93,6 @@ class DataParallelBuilder(GraphBuilder):
                 # so these duplicated variables won't be saved by default.
                 with override_to_local_variable(enable=usevs):
                     ret.append(func())
-
-                if idx == 0:
-                    # avoid duplicated summary & update_ops from each device
-                    backup = backup_collection(TOWER_FREEZE_KEYS)
-        restore_collection(backup)
         return ret
@@ -27,7 +27,10 @@ os.environ['TF_AVGPOOL_USE_CUDNN'] = '1'  # issue#8566
 try:
     import tensorflow as tf  # noqa
-    assert int(tf.__version__.split('.')[0]) >= 1, "TF>=1.0 is required!"
+    _version = tf.__version__.split('.')
+    assert int(_version[0]) >= 1, "TF>=1.0 is required!"
+    if int(_version[1]) < 2:
+        print("TF<1.2 support will be removed in the future!")
     _HAS_TF = True
 except ImportError:
     _HAS_TF = False
@@ -48,11 +48,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
         # because the vs_name used in inference can be '', therefore the
         # variable filter will fail
         return tf.constant(0, dtype=tf.float32, name='empty_' + name)
-    params = tf.trainable_variables()

-    # If vars are shared, use all of them
+    # If vars are shared, regularize all of them
     # If vars are replicated, only regularize those in the current tower
-    params = ctx.filter_vars_by_vs_name(params)
+    if ctx.has_own_variables:
+        params = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+    else:
+        params = tf.trainable_variables()

     G = tf.get_default_graph()
@@ -93,21 +95,22 @@ def regularize_cost_from_collection(name='regularize_cost'):
     Returns:
         a scalar tensor, the regularization loss, or None
     """
-    regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
     ctx = get_current_tower_context()
     if not ctx.is_training:
-        # Currently cannot build the wd_cost correctly at inference,
+        # TODO Currently cannot build the wd_cost correctly at inference,
         # because the vs_name used in inference can be '', therefore the
        # variable filter will fail
         return None
-    if len(regularization_losses) > 0:
-        # NOTE: this collection doesn't grow with towers.
-        # It is only added with variables that are newly created.
-        if ctx.has_own_variables:    # be careful of the first tower (name='')
-            regularization_losses = ctx.filter_vars_by_vs_name(regularization_losses)
-        logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
-        reg_loss = tf.add_n(list(regularization_losses), name=name)
+    # NOTE: this collection doesn't always grow with towers.
+    # It is only added with variables that are newly created.
+    if ctx.has_own_variables:    # be careful of the first tower (name='')
+        losses = ctx.get_collection_in_tower(tf.GraphKeys.REGULARIZATION_LOSSES)
+    else:
+        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
+    if len(losses) > 0:
+        logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(losses)))
+        reg_loss = tf.add_n(losses, name=name)
         return reg_loss
     else:
         return None
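
What get_collection_in_tower buys in this function, as a standalone TF 1.x sketch; the snapshot-and-diff below mirrors what the guard does for the REGULARIZATION_LOSSES key (l2_regularizer attaches one loss per variable):

    import tensorflow as tf

    reg = tf.contrib.layers.l2_regularizer(1e-4)
    w0 = tf.get_variable('w0', [2], regularizer=reg)   # e.g. created before this tower
    snapshot = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    with tf.variable_scope('tower1'):                  # created inside the tower
        w1 = tf.get_variable('w1', [2], regularizer=reg)
    in_tower = [l for l in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
                if l not in snapshot]
    # in_tower holds only w1's loss, so a replicated tower regularizes
    # its own variable copies and not those of other towers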
@@ -8,21 +8,27 @@ from copy import copy
 import six
 from contextlib import contextmanager

+from ..utils import logger
+from ..utils.argtools import memoized

 __all__ = ['backup_collection',
            'restore_collection',
            'freeze_collection']


-def backup_collection(keys):
+def backup_collection(keys=None):
     """
     Args:
-        keys (list): list of collection keys to backup
+        keys (list): list of collection keys to backup.
+            Defaults to all keys in the graph.

     Returns:
         dict: the backup
     """
+    if keys is None:
+        keys = tf.get_default_graph().get_all_collection_keys()
     ret = {}
-    assert isinstance(keys, (list, tuple))
+    assert isinstance(keys, (list, tuple, set))
     for k in keys:
         ret[k] = copy(tf.get_collection(k))
     return ret
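
For context, a standalone sketch of the backup/restore pair this hunk generalizes; restore_collection itself is unchanged by the commit, and the version below is a hedged mirror of its documented behavior (TF 1.x):

    import tensorflow as tf
    from copy import copy

    def backup_collection(keys=None):
        # keys=None (the new default) snapshots every collection in the graph
        if keys is None:
            keys = tf.get_default_graph().get_all_collection_keys()
        return {k: copy(tf.get_collection(k)) for k in keys}

    def restore_collection(backup):
        for k, v in backup.items():
            del tf.get_collection_ref(k)[:]      # clear the collection in place
            tf.get_collection_ref(k).extend(v)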
@@ -52,3 +58,103 @@ def freeze_collection(keys):
     backup = backup_collection(keys)
     yield
     restore_collection(backup)
+
+
+@memoized
+def get_inverse_graphkeys():
+    ret = {}
+    for name in dir(tf.GraphKeys):
+        if name.startswith('_'):
+            continue
+        if name in ['VARIABLES']:   # will produce deprecated warning
+            continue
+        ret[getattr(tf.GraphKeys, name)] = "tf.GraphKeys.{}".format(name)
+    return ret
+
+
+class CollectionGuard(object):
+    """
+    A context to maintain collection change in a tower.
+    """
+
+    original = None
+
+    def __init__(self, name, check_diff,
+                 freeze_keys=[],
+                 diff_whitelist=[
+                     tf.GraphKeys.TRAINABLE_VARIABLES,
+                     tf.GraphKeys.GLOBAL_VARIABLES,
+                     tf.GraphKeys.LOCAL_VARIABLES]):
+        """
+        Args:
+            name (str): name of the tower
+            check_diff (bool): whether to test and print about collection change
+            freeze_keys (list): list of keys to freeze
+            diff_whitelist (list): list of keys to not print, when check_diff is True
+        """
+        self._name = name
+        self._check_diff = check_diff
+        self._whitelist = set(diff_whitelist)
+        self._freeze_keys = freeze_keys
+        self._inverse_graphkeys = get_inverse_graphkeys()
+
+    def _key_name(self, name):
+        return self._inverse_graphkeys.get(name, name)
+
+    def __enter__(self):
+        self.original = backup_collection()
+        self._freeze_backup = backup_collection(self._freeze_keys)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_type is not None:
+            return False
+        new_coll = backup_collection()
+        if self._check_diff:
+            self._print_diff(new_coll)
+        self._restore_freeze(new_coll)
+        return False
+
+    def _print_diff(self, new):
+        newly_created = []
+        size_change = []
+        for k, v in six.iteritems(new):
+            if k in self._whitelist or k in self._freeze_keys:
+                continue
+            if k not in self.original:
+                newly_created.append(self._key_name(k))
+            else:
+                old_v = self.original[k]
+                if len(old_v) != len(v):
+                    size_change.append((self._key_name(k), len(old_v), len(v)))
+        if newly_created:
+            logger.info(
+                "New collections created in {}: {}".format(
+                    self._name, ', '.join(newly_created)))
+        if size_change:
+            logger.info(
+                "Size of these collections were changed in {}: {}".format(
+                    self._name, ', '.join(
+                        map(lambda t: "({}: {}->{})".format(*t),
+                            size_change))))
+
+    def _restore_freeze(self, new):
+        size_change = []
+        for k, v in six.iteritems(self._freeze_backup):
+            newv = new.get(k, [])
+            if len(v) != len(newv):
+                size_change.append((self._key_name(k), len(v), len(newv)))
+        if size_change:
+            logger.info(
+                "These collections were modified but restored in {}: {}".format(
+                    self._name, ', '.join(
+                        map(lambda t: "({}: {}->{})".format(*t),
+                            size_change))))
+        restore_collection(self._freeze_backup)
+
+    def get_collection_in_tower(self, key):
+        """
+        Get items from this collection that are added in the current tower.
+        """
+        new = set(tf.get_collection(key))
+        old = set(self.original.get(key, []))
+        return list(new - old)
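
A hedged usage sketch of CollectionGuard on its own, outside any TowerContext; it relies on _restore_freeze() rolling the frozen keys back as above:

    import tensorflow as tf

    update_op = tf.no_op(name='fake_update')
    n_before = len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

    with CollectionGuard('pred_tower', check_diff=True,
                         freeze_keys=[tf.GraphKeys.UPDATE_OPS]):
        # e.g. batch-norm update ops created while building the tower
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

    # the frozen key is rolled back on exit (the change is also logged)
    assert len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) == n_before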
@@ -8,6 +8,8 @@ from six.moves import zip
 from ..utils import logger
 from ..utils.argtools import call_only_once
+from ..utils.naming import TRAIN_TOWER_FREEZE_KEYS, PREDICT_TOWER_FREEZE_KEYS
+from .collection import CollectionGuard
 from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name

 __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
@@ -44,6 +46,11 @@ class TowerContext(object):
             assert not self._initial_vs_reuse, \
                 "Cannot create tower {} with reuse=True!".format(tower_name)

+        self._collection_guard = CollectionGuard(
+            self._name,
+            check_diff=not self.is_main_training_tower,
+            freeze_keys=self._keys_to_freeze())
+
     @property
     def is_main_training_tower(self):
         return self.is_training and self._index == 0
@@ -91,6 +98,12 @@ class TowerContext(object):
         prefix = self._vs_name + '/'
         return [v for v in varlist if v.op.name.startswith(prefix)]

+    def get_collection_in_tower(self, key):
+        """
+        Get items from this collection that are added in the current tower.
+        """
+        return self._collection_guard.get_collection_in_tower(key)
+
     @property
     def index(self):
         return self._index
@@ -117,6 +130,13 @@ class TowerContext(object):
             ret.append(tf.name_scope(self._name + '/'))
         return ret

+    def _keys_to_freeze(self):
+        if self.is_main_training_tower:
+            return []
+        if self.is_training:
+            return TRAIN_TOWER_FREEZE_KEYS
+        return PREDICT_TOWER_FREEZE_KEYS
+
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
@@ -125,6 +145,7 @@ class TowerContext(object):
         assert curr_vs.name == '', "Cannot nest TowerContext with an existing variable scope!"

         self._ctxs = self._get_scopes()
+        self._ctxs.append(self._collection_guard)
         for c in self._ctxs:
             c.__enter__()
@@ -139,6 +160,12 @@ class TowerContext(object):
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CurrentTowerContext
         _CurrentTowerContext = None
+        if not self.has_own_variables:
+            diff_trainable_vars = self._collection_guard.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            assert len(diff_trainable_vars) == 0, \
+                "New TRAINABLE_VARIABLES shouldn't be created in {}: ".format(
+                    self._name) + ', '.join([k.name for k in diff_trainable_vars])
         for c in self._ctxs[::-1]:
             c.__exit__(exc_type, exc_val, exc_tb)
         return False
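
Putting the tower.py changes together, a hedged end-to-end sketch (assumes the TowerContext(..., vs_name=...) signature used elsewhere in this diff): a tower with its own variable scope may create variables, and the guard reports exactly those; a shared-variable tower creating one would trip the new __exit__ assertion.

    import tensorflow as tf
    from tensorpack.tfutils.tower import TowerContext, get_current_tower_context

    with TowerContext('tower1', is_training=True, vs_name='tower1'):
        ctx = get_current_tower_context()
        w = tf.get_variable('w', [4])   # created under the tower's variable scope
        assert w in ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)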
@@ -156,7 +156,10 @@ class SingleCostTrainer(TowerTrainer):
             ctx = get_current_tower_context()
             cost = get_cost_fn(*input.get_input_tensors())

-            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+            if ctx.has_own_variables:
+                varlist = ctx.get_collection_in_tower(tf.GraphKeys.TRAINABLE_VARIABLES)
+            else:
+                varlist = tf.trainable_variables()
             opt = get_opt_fn()
             grads = opt.compute_gradients(
                 cost, var_list=varlist,
@@ -12,4 +12,7 @@ MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'
 SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

-TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS
+TRAIN_TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS
+PREDICT_TOWER_FREEZE_KEYS = SUMMARY_BACKUP_KEYS + [tf.GraphKeys.UPDATE_OPS]
+# also freeze UPDATE_OPS in inference, because they should never be used
+# TODO a better way to log and warn about collection change during build_graph.