Commit 85303dd2 authored by Yuxin Wu

Split TowerContext into TrainTowerContext and PredictTowerContext, to support more informative contexts in the future.

parent 9fc641e4
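The user-facing effect of the split is small: the old `TowerContext(name, is_training=...)` spelling is kept as a compatibility factory (see the end of the tower.py diff below), while new code can name its intent directly. A minimal sketch against this revision of tensorpack, with no-op tower bodies as placeholders:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import (
    TowerContext, TrainTowerContext, PredictTowerContext)

# Old spelling, still valid: TowerContext is now a factory function
# that dispatches on is_training.
with TowerContext('', is_training=True):
    pass  # build a training tower here

# New spelling: construct the specific context directly.
with TrainTowerContext('tower0', vs_name='tower0', index=0, total=2):
    pass  # one of several training towers, with its own variable scope

with PredictTowerContext(''):
    pass  # an inference tower; no is_training / index arguments
```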
@@ -4,7 +4,7 @@
 import tensorflow as tf

 from ..utils import logger
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import PredictTowerContext
 from .training import GraphBuilder

 __all__ = ['SimplePredictBuilder']
@@ -41,8 +41,8 @@ class SimplePredictBuilder(GraphBuilder):
             self._ns_name, self._device))

         with tf.device(self._device), \
-                TowerContext(
-                    self._ns_name, is_training=False, vs_name=self._vs_name):
+                PredictTowerContext(
+                    self._ns_name, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
             assert isinstance(inputs, (list, tuple)), inputs
             return tower_fn(*inputs)
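The only change here is which context wraps the predictor tower; the tower function itself is untouched. A hedged before/after sketch of just that wrapping (the device and names are illustrative values, not taken from the diff):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext, PredictTowerContext

device, ns_name, vs_name = '/gpu:0', 'towerp0', ''  # illustrative values

# Before this commit (still runs afterwards, via the compatibility factory):
with tf.device(device), \
        TowerContext(ns_name, is_training=False, vs_name=vs_name):
    pass  # return tower_fn(*inputs)

# After this commit:
with tf.device(device), \
        PredictTowerContext(ns_name, vs_name=vs_name):
    pass  # return tower_fn(*inputs)
```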
@@ -8,9 +8,10 @@ import six
 import re
 import pprint
 from six.moves import zip, range
+from contextlib import contextmanager

 from ..utils import logger
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import TrainTowerContext
 from ..tfutils.gradproc import ScaleGradient
 from .utils import (
@@ -31,6 +32,15 @@ class GraphBuilder(object):
         pass


+@contextmanager
+def _maybe_reuse_vs(reuse):
+    if reuse:
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            yield
+    else:
+        yield
+
+
 class DataParallelBuilder(GraphBuilder):
     def __init__(self, towers):
         """
@@ -92,11 +102,11 @@ class DataParallelBuilder(GraphBuilder):
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
-            with tf.device(device), TowerContext(
+            reuse = not usevs and idx > 0
+            with tf.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
                     tower_names[idx],
-                    is_training=True,
-                    index=idx,
-                    vs_name=tower_names[idx] if usevs else ''):
+                    vs_name=tower_names[idx] if usevs else '',
+                    index=idx, total=len(towers)):
                 if len(str(device)) < 10:  # a device function doesn't have good string description
                     logger.info("Building graph for training tower {} on device {} ...".format(idx, device))
                 else:
...
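Previously, variable reuse for non-first training towers was decided inside `TowerContext._get_scopes`; this commit moves that decision into the builder via `_maybe_reuse_vs`, so the context only has to assert it was done. The sketch below (plain TF 1.x, no tensorpack; all names illustrative) shows the underlying pattern the helper factors out: tower 0 creates the variables, later towers without their own variable scope reuse them.

```python
import tensorflow as tf

def build_tower(x):
    # tf.get_variable is what `reuse` affects; name scopes stay per-tower.
    w = tf.get_variable('w', shape=[1])
    return x * w

x = tf.placeholder(tf.float32, shape=[1])
outputs = []
for idx in range(2):
    reuse = idx > 0  # same condition as `not usevs and idx > 0` above
    with tf.variable_scope(tf.get_variable_scope(), reuse=reuse), \
            tf.name_scope('tower{}/'.format(idx)):
        outputs.append(build_tower(x))
# Both towers share the single variable 'w'; their ops live under tower0/, tower1/.
```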
@@ -7,7 +7,7 @@ import tensorflow as tf
 import six

 from ..tfutils.common import get_tensors_by_names, get_tf_version_number
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import PredictTowerContext
 from ..input_source import PlaceholderInput
 from ..utils.develop import log_deprecated
 from ..utils.argtools import log_once
@@ -165,7 +165,7 @@ class OfflinePredictor(OnlinePredictor):
         with self.graph.as_default():
             input = PlaceholderInput()
             input.setup(config.inputs_desc)
-            with TowerContext('', is_training=False):
+            with PredictTowerContext(''):
                 config.tower_func(*input.get_input_tensors())

             input_tensors = get_tensors_by_names(config.input_names)
...
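For context, this is how an inference-only graph gets built OfflinePredictor-style under the new API; a minimal self-contained sketch in which the tower function and tensor names are assumptions:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import PredictTowerContext

def my_tower_fn(image):
    # hypothetical tower function: a single named output op
    return tf.identity(image, name='output')

graph = tf.Graph()
with graph.as_default():
    image = tf.placeholder(tf.float32, shape=[None, 4], name='input')
    with PredictTowerContext(''):  # no is_training argument anymore
        my_tower_fn(image)
    output = graph.get_tensor_by_name('output:0')
```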
@@ -86,7 +86,7 @@ class CollectionGuard(object):
             name (str): name of the tower
             check_diff (bool): whether to check and print about collection change
                 when leaving this guard.
-            freeze_keys (list): list of keys to freeze
+            freeze_keys (list): list of keys to backup when entering and restore when leaving this guard.
             diff_whitelist (list): list of keys to ignore, when check_diff is True.
                 Defaults to some collections that are normally changed,
                 including variables, losses, contexts, queue runners.
...
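The docstring fix pins down what "freeze" means here: the guard snapshots those collections on enter and restores the snapshot on exit, so additions made inside the tower do not leak out. A toy illustration of that semantics (not CollectionGuard's actual implementation):

```python
import tensorflow as tf
from contextlib import contextmanager

@contextmanager
def freeze_collections(keys):
    graph = tf.get_default_graph()
    backup = {k: list(graph.get_collection(k)) for k in keys}  # snapshot on enter
    try:
        yield
    finally:
        for k, items in backup.items():
            coll = graph.get_collection_ref(k)  # mutable reference
            del coll[:]
            coll.extend(items)                  # restore on exit

with freeze_collections([tf.GraphKeys.UPDATE_OPS]):
    tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, tf.no_op())
assert len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) == 0  # change was rolled back
```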
@@ -3,14 +3,16 @@
 import tensorflow as tf
 import six
 from six.moves import zip
+from abc import abstractproperty, abstractmethod, ABCMeta

 from ..utils import logger
 from ..utils.argtools import call_only_once
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from ..utils.develop import HIDE_DOC
 from .collection import CollectionGuard
-from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
+from .common import get_op_or_tensor_by_name, get_op_tensor_name

 __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
            'TowerTensorHandle', 'TowerTensorHandles']
@@ -18,57 +20,34 @@ __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
 _CurrentTowerContext = None


-class TowerContext(object):
+@six.add_metaclass(ABCMeta)
+class BaseTowerContext(object):
     """ A context where the current model is built in.
+    Since TF 1.8, TensorFlow starts to introduce the same concept.
     """

-    def __init__(self, tower_name, is_training, index=0, vs_name=''):
+    def __init__(self, ns_name, vs_name=''):
         """
         Args:
-            tower_name (str): The name scope of the tower.
-            is_training (bool):
-            index (int): index of this tower, only used in training.
+            ns_name (str): The name scope of the tower.
             vs_name (str): Open a new variable scope with this name.
         """
-        self._name = tower_name
-        self._is_training = bool(is_training)
-        if not self._is_training:
-            assert index == 0, \
-                "TowerContext(index) is only valid in training!"
-        self._index = int(index)
+        self._name = ns_name
         self._vs_name = vs_name

         if len(vs_name):
-            assert len(tower_name), "TowerContext(vs_name) cannot be used with an empty tower_name!"
-
-        self._initial_vs_reuse = tf.get_variable_scope().reuse
-        if self.has_own_variables:
-            assert not self._initial_vs_reuse, \
-                "Cannot create tower {} with reuse=True!".format(tower_name)
-
-        self._collection_guard = CollectionGuard(
-            self._name,
-            check_diff=not self.is_main_training_tower,
-            freeze_keys=self._keys_to_freeze())
+            assert len(ns_name), "TowerContext(vs_name) cannot be used with an empty name!"

-    @property
+    @abstractproperty
     def is_main_training_tower(self):
-        return self.is_training and self._index == 0
+        pass

-    @property
-    def is_training(self):
-        return self._is_training
-
-    @property
+    @abstractproperty
     def has_own_variables(self):
         """
         Whether this tower is supposed to have its own variables.
         """
-        return self.is_main_training_tower or \
-            (self.is_training and len(self._vs_name) > 0) or \
-            (not self.is_training and not self._initial_vs_reuse)
+        pass

     @property
     def name(self):
@@ -89,59 +68,47 @@ class TowerContext(object):
         """
         return self._collection_guard.get_collection_in_tower(key)

-    # TODO currently only used in StagingInput
-    @property
-    def index(self):
-        return self._index
-
     @call_only_once
     def _get_scopes(self):
         """
         Returns the ns and vs for this tower.
         """
         if not len(self._name):
             # work around https://github.com/tensorflow/tensorflow/issues/14703
             return [tf.variable_scope(tf.get_variable_scope())]

-        ret = []
-        # either the Tower was originally created with reuse,
-        # or a training tower without vs has to use reuse.
-        reuse = (self.is_training and self._index > 0 and not
-                 self.has_own_variables) or self._initial_vs_reuse
-
+        ret = []
         if len(self._vs_name):
-            ret.append(tf.variable_scope(self._vs_name, reuse=reuse))
+            ret.append(tf.variable_scope(self._vs_name))
         else:
-            if reuse:
-                ret.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=True))
-            else:
-                # work around https://github.com/tensorflow/tensorflow/issues/14703
+            # caller should have handled reuse outside of TowerContext
             ret.append(tf.variable_scope(tf.get_variable_scope()))

         # always clear existing ns  # TODO check existing ns
         if len(self._name) and self._name != self._vs_name:
             ret.append(tf.name_scope(self._name + '/'))
         return ret

+    @abstractmethod
     def _keys_to_freeze(self):
-        if self.is_main_training_tower:
-            return []
-        if self.is_training:
-            return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
-        # freeze UPDATE_OPS during inference because they should never be used
-        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY, tf.GraphKeys.UPDATE_OPS]
+        pass

     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
-        if self.is_training:
-            curr_vs = tf.get_variable_scope()
-            assert curr_vs.name == '', "In training, cannot nest TowerContext with an existing variable scope!"
+
+        self._collection_guard = CollectionGuard(
+            self._name,
+            check_diff=not self.is_main_training_tower,
+            freeze_keys=self._keys_to_freeze())

         self._ctxs = self._get_scopes()
         self._ctxs.append(self._collection_guard)
         for c in self._ctxs:
             c.__enter__()

-        if get_tf_version_number() >= 1.2:
             # check that ns_name is always the same as _name
             ns = tf.get_default_graph().get_name_scope()
             assert ns == self._name, \
@@ -167,10 +134,89 @@ class TowerContext(object):
             self._name, self._is_training)


+class TrainTowerContext(BaseTowerContext):
+
+    is_training = True
+
+    def __init__(self, ns_name, vs_name='', index=0, total=1):
+        """
+        Args:
+            index (int): index of this tower, only used in training.
+            total (int): total number of towers to be built.
+        """
+        super(TrainTowerContext, self).__init__(ns_name, vs_name)
+
+        self.index = int(index)
+        self.total = int(total)
+        if self.index > 0:
+            assert self.total > self.index, "(index, total) = ({}, {})".format(self.index, self.total)
+
+        vs = tf.get_variable_scope()
+        assert vs.name == '', "Cannot nest TrainTowerContext with an existing variable scope!"
+        if self.has_own_variables:
+            assert not vs.reuse, \
+                "Cannot create tower {} under reuse=True!".format(ns_name)
+
+    @property
+    def is_main_training_tower(self):
+        return self.index == 0
+
+    @property
+    def has_own_variables(self):
+        return self.index == 0 or len(self._vs_name) > 0
+
+    def _keys_to_freeze(self):
+        if self.index == 0:
+            return []
+        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
+
+
+class PredictTowerContext(BaseTowerContext):
+
+    is_training = False
+
+    def __init__(self, ns_name, vs_name=''):
+        super(PredictTowerContext, self).__init__(ns_name, vs_name)
+        self._initial_vs_reuse = tf.get_variable_scope().reuse
+
+    @property
+    def has_own_variables(self):
+        return not self._initial_vs_reuse
+
+    @property
+    def is_main_training_tower(self):
+        return False
+
+    def _keys_to_freeze(self):
+        # freeze UPDATE_OPS during inference because they should never be used
+        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY, tf.GraphKeys.UPDATE_OPS]
+
+
 def get_current_tower_context():
     return _CurrentTowerContext


+def TowerContext(tower_name, is_training, vs_name=''):
+    """
+    User-facing API to build a tower manually.
+
+    Returns:
+        A context within which the tower function should be called.
+
+    Examples:
+
+    .. code-block:: python
+
+        with TowerContext('', is_training=True):
+            # call a tensorpack layer or a tower function
+
+    """
+    if is_training:
+        return TrainTowerContext(tower_name, vs_name=vs_name)
+    else:
+        return PredictTowerContext(tower_name, vs_name=vs_name)
+
+
 class TowerFuncWrapper(object):
     """
     A wrapper around a tower function (function which builds one tower, i.e. one replicate of the model).
...
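TowerFuncWrapper itself is unchanged by this commit, but it is the main consumer of these contexts: trainers and predictors call the wrapped function under a Train/PredictTowerContext. A hedged sketch of that idiom, in which the model, shapes, and the `TowerFuncWrapper(fn, inputs_desc)` signature are assumptions based on this era of the API:

```python
import tensorflow as tf
from tensorpack import InputDesc
from tensorpack.tfutils.tower import TowerFuncWrapper, PredictTowerContext

def tower_fn(image):
    # hypothetical model: one dense layer over the flattened input
    logits = tf.layers.dense(tf.layers.flatten(image), 10)
    return tf.identity(logits, name='logits')

tower_func = TowerFuncWrapper(
    tower_fn, [InputDesc(tf.float32, [None, 28, 28, 1], 'image')])

image = tf.placeholder(tf.float32, [None, 28, 28, 1], 'image')
with PredictTowerContext(''):
    tower_func(image)  # records the tower so its tensors can be retrieved later
```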
@@ -13,7 +13,7 @@ from ..utils.argtools import map_arg
 from ..utils.develop import HIDE_DOC
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import TrainTowerContext
 from ..input_source import QueueInput
 from ..graph_builder.training import (
@@ -49,7 +49,7 @@ class SimpleTrainer(SingleCostTrainer):
     """
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         logger.info("Building graph for a single training tower ...")
-        with TowerContext('', is_training=True):
+        with TrainTowerContext(''):
             grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
             opt = get_opt_fn()
             self.train_op = opt.apply_gradients(grads, name='min_op')
@@ -359,7 +359,7 @@ class HorovodTrainer(SingleCostTrainer):
         return averaged_gradients

     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        with TowerContext('', is_training=True):
+        with TrainTowerContext(''):
             grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
             grads = self.allreduce(grads)
...
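Both trainers build their single tower inside an anonymous `TrainTowerContext('')`. Stripped of the trainer machinery, the pattern amounts to roughly the following; the cost, optimizer, and shapes are assumptions for illustration:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TrainTowerContext

x = tf.placeholder(tf.float32, [None, 2], 'x')

with TrainTowerContext(''):  # replaces TowerContext('', is_training=True)
    w = tf.get_variable('w', [2, 1])
    cost = tf.reduce_mean(tf.matmul(x, w) ** 2, name='cost')
    opt = tf.train.GradientDescentOptimizer(0.1)
    grads = opt.compute_gradients(cost)
    train_op = opt.apply_gradients(grads, name='min_op')
```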