Commit 85303dd2 authored by Yuxin Wu

Split TowerContext into TrainTowerContext and PredictTowerContext to support more informative contexts in the future.
parent 9fc641e4
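
In short, call sites that used `TowerContext(name, is_training=...)` now pick one of two explicit classes (the old name survives as a small factory, shown near the end of tower.py below). A minimal sketch of the resulting API under TF 1.x graph mode; `tower_fn` is a made-up stand-in for any tower function:

    import tensorflow as tf
    from tensorpack.tfutils.tower import TrainTowerContext, PredictTowerContext

    def tower_fn(x):
        # builds one replicate of the model
        w = tf.get_variable('w', [4, 10])
        return tf.matmul(x, w)

    x = tf.placeholder(tf.float32, [None, 4])

    # training tower: index/total locate it among the replicas
    with TrainTowerContext('', index=0, total=1):
        logits = tower_fn(x)

    # inference tower: there is no is_training flag anymore, and variable
    # reuse is now the caller's job, handled outside the context
    with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
            PredictTowerContext('pred'):
        pred_logits = tower_fn(x)
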
@@ -4,7 +4,7 @@
 import tensorflow as tf
 
 from ..utils import logger
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import PredictTowerContext
 from .training import GraphBuilder
 
 __all__ = ['SimplePredictBuilder']
@@ -41,8 +41,8 @@ class SimplePredictBuilder(GraphBuilder):
             self._ns_name, self._device))
 
         with tf.device(self._device), \
-                TowerContext(
-                    self._ns_name, is_training=False, vs_name=self._vs_name):
+                PredictTowerContext(
+                    self._ns_name, vs_name=self._vs_name):
             inputs = input.get_input_tensors()
             assert isinstance(inputs, (list, tuple)), inputs
             return tower_fn(*inputs)
@@ -8,9 +8,10 @@ import six
 import re
 import pprint
 from six.moves import zip, range
+from contextlib import contextmanager
 
 from ..utils import logger
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import TrainTowerContext
 from ..tfutils.gradproc import ScaleGradient
 from .utils import (
@@ -31,6 +32,15 @@ class GraphBuilder(object):
         pass
 
 
+@contextmanager
+def _maybe_reuse_vs(reuse):
+    if reuse:
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            yield
+    else:
+        yield
+
+
 class DataParallelBuilder(GraphBuilder):
     def __init__(self, towers):
         """
@@ -92,11 +102,11 @@ class DataParallelBuilder(GraphBuilder):
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
-            with tf.device(device), TowerContext(
-                    tower_names[idx],
-                    is_training=True,
-                    index=idx,
-                    vs_name=tower_names[idx] if usevs else ''):
+            reuse = not usevs and idx > 0
+            with tf.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
+                    tower_names[idx],
+                    vs_name=tower_names[idx] if usevs else '',
+                    index=idx, total=len(towers)):
                 if len(str(device)) < 10:  # a device function doesn't have good string description
                     logger.info("Building graph for training tower {} on device {} ...".format(idx, device))
                 else:
...
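
The `_maybe_reuse_vs` helper added above moves reuse handling out of the tower context: any later tower that does not get its own variable scope re-enters the current scope with `reuse=True`. A self-contained demo of that rule (TF 1.x; `tower_fn` and the loop are illustrative, not the builder's actual code):

    import tensorflow as tf
    from contextlib import contextmanager

    @contextmanager
    def maybe_reuse_vs(reuse):  # mirrors the new _maybe_reuse_vs
        if reuse:
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                yield
        else:
            yield

    def tower_fn():
        return tf.get_variable('w', shape=[3])

    ws = []
    for idx in range(2):
        usevs = False                  # towers share variables here
        reuse = not usevs and idx > 0  # same rule as DataParallelBuilder
        with maybe_reuse_vs(reuse):
            ws.append(tower_fn())
    assert ws[0] is ws[1]              # tower 1 reused tower 0's variable
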
@@ -7,7 +7,7 @@ import tensorflow as tf
 import six
 
 from ..tfutils.common import get_tensors_by_names, get_tf_version_number
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import PredictTowerContext
 from ..input_source import PlaceholderInput
 from ..utils.develop import log_deprecated
 from ..utils.argtools import log_once
@@ -165,7 +165,7 @@ class OfflinePredictor(OnlinePredictor):
         with self.graph.as_default():
             input = PlaceholderInput()
             input.setup(config.inputs_desc)
-            with TowerContext('', is_training=False):
+            with PredictTowerContext(''):
                 config.tower_func(*input.get_input_tensors())
 
             input_tensors = get_tensors_by_names(config.input_names)
...
@@ -86,7 +86,7 @@ class CollectionGuard(object):
             name (str): name of the tower
             check_diff (bool): whether to check and print about collection change
                 when leaving this guard.
-            freeze_keys (list): list of keys to freeze
+            freeze_keys (list): list of keys to backup when entering and restore when leaving this guard.
             diff_whitelist (list): list of keys to ignore, when check_diff is True.
                 Defaults to some collections that are normally changed,
                 including variables, losses, contexts, queue runners.
...
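
The reworded `freeze_keys` docstring describes backup-and-restore semantics. Roughly, ignoring CollectionGuard's other duties (a sketch of the idea, not its actual implementation):

    import tensorflow as tf

    graph = tf.get_default_graph()
    key = tf.GraphKeys.UPDATE_OPS

    backup = graph.get_collection(key)        # on __enter__: snapshot the frozen key
    # ... the tower function may append to the collection here ...
    graph.clear_collection(key)               # on __exit__: drop tower-local additions
    for item in backup:
        graph.add_to_collection(key, item)    # and restore the snapshot
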
@@ -3,14 +3,16 @@
 
 import tensorflow as tf
+import six
 from six.moves import zip
+from abc import abstractproperty, abstractmethod, ABCMeta
 
 from ..utils import logger
 from ..utils.argtools import call_only_once
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from ..utils.develop import HIDE_DOC
 from .collection import CollectionGuard
-from .common import get_tf_version_number, get_op_or_tensor_by_name, get_op_tensor_name
+from .common import get_op_or_tensor_by_name, get_op_tensor_name
 
 __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
            'TowerTensorHandle', 'TowerTensorHandles']
@@ -18,57 +20,34 @@ __all__ = ['get_current_tower_context', 'TowerContext', 'TowerFuncWrapper',
 _CurrentTowerContext = None
 
 
-class TowerContext(object):
+@six.add_metaclass(ABCMeta)
+class BaseTowerContext(object):
     """ A context where the current model is built in.
     Since TF 1.8, TensorFlow starts to introduce the same concept.
     """
 
-    def __init__(self, tower_name, is_training, index=0, vs_name=''):
+    def __init__(self, ns_name, vs_name=''):
         """
         Args:
-            tower_name (str): The name scope of the tower.
-            is_training (bool):
-            index (int): index of this tower, only used in training.
+            ns_name (str): The name scope of the tower.
             vs_name (str): Open a new variable scope with this name.
         """
-        self._name = tower_name
-        self._is_training = bool(is_training)
-
-        if not self._is_training:
-            assert index == 0, \
-                "TowerContext(index) is only valid in training!"
-
-        self._index = int(index)
+        self._name = ns_name
         self._vs_name = vs_name
+
         if len(vs_name):
-            assert len(tower_name), "TowerContext(vs_name) cannot be used with an empty tower_name!"
+            assert len(ns_name), "TowerContext(vs_name) cannot be used with an empty name!"
 
-        self._initial_vs_reuse = tf.get_variable_scope().reuse
-        if self.has_own_variables:
-            assert not self._initial_vs_reuse, \
-                "Cannot create tower {} with reuse=True!".format(tower_name)
-
-        self._collection_guard = CollectionGuard(
-            self._name,
-            check_diff=not self.is_main_training_tower,
-            freeze_keys=self._keys_to_freeze())
-
-    @property
+    @abstractproperty
     def is_main_training_tower(self):
-        return self.is_training and self._index == 0
+        pass
 
-    @property
-    def is_training(self):
-        return self._is_training
-
-    @property
+    @abstractproperty
     def has_own_variables(self):
         """
         Whether this tower is supposed to have its own variables.
         """
-        return self.is_main_training_tower or \
-            (self.is_training and len(self._vs_name) > 0) or \
-            (not self.is_training and not self._initial_vs_reuse)
+        pass
 
     @property
     def name(self):
@@ -89,59 +68,47 @@ class TowerContext(object):
         """
         return self._collection_guard.get_collection_in_tower(key)
 
-    # TODO currently only used in StagingInput
-    @property
-    def index(self):
-        return self._index
-
     @call_only_once
     def _get_scopes(self):
+        """
+        Returns the ns and vs for this tower.
+        """
         if not len(self._name):
             # work around https://github.com/tensorflow/tensorflow/issues/14703
             return [tf.variable_scope(tf.get_variable_scope())]
-        ret = []
 
-        # either the Tower was originally created with reuse,
-        # or a training tower without vs has to use reuse.
-        reuse = (self.is_training and self._index > 0 and not
-                 self.has_own_variables) or self._initial_vs_reuse
+        ret = []
         if len(self._vs_name):
-            ret.append(tf.variable_scope(self._vs_name, reuse=reuse))
+            ret.append(tf.variable_scope(self._vs_name))
         else:
-            if reuse:
-                ret.append(tf.variable_scope(
-                    tf.get_variable_scope(), reuse=True))
-            else:
-                # work around https://github.com/tensorflow/tensorflow/issues/14703
-                ret.append(tf.variable_scope(tf.get_variable_scope()))
+            # caller should have handled reuse outside of TowerContext
+            ret.append(tf.variable_scope(tf.get_variable_scope()))
 
         # always clear existing ns  # TODO check existing ns
         if len(self._name) and self._name != self._vs_name:
             ret.append(tf.name_scope(self._name + '/'))
         return ret
 
+    @abstractmethod
     def _keys_to_freeze(self):
-        if self.is_main_training_tower:
-            return []
-        if self.is_training:
-            return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
-        # freeze UPDATE_OPS during inference because they should never be used
-        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY, tf.GraphKeys.UPDATE_OPS]
+        pass
 
     def __enter__(self):
         global _CurrentTowerContext
         assert _CurrentTowerContext is None, "Cannot nest TowerContext!"
         _CurrentTowerContext = self
-        if self.is_training:
-            curr_vs = tf.get_variable_scope()
-            assert curr_vs.name == '', "In training, cannot nest TowerContext with an existing variable scope!"
+
+        self._collection_guard = CollectionGuard(
+            self._name,
+            check_diff=not self.is_main_training_tower,
+            freeze_keys=self._keys_to_freeze())
+
         self._ctxs = self._get_scopes()
         self._ctxs.append(self._collection_guard)
         for c in self._ctxs:
             c.__enter__()
 
-        if get_tf_version_number() >= 1.2:
-            # check that ns_name is always the same as _name
-            ns = tf.get_default_graph().get_name_scope()
-            assert ns == self._name, \
+        # check that ns_name is always the same as _name
+        ns = tf.get_default_graph().get_name_scope()
+        assert ns == self._name, \
@@ -167,10 +134,89 @@
             self._name, self._is_training)
 
 
+class TrainTowerContext(BaseTowerContext):
+
+    is_training = True
+
+    def __init__(self, ns_name, vs_name='', index=0, total=1):
+        """
+        Args:
+            index (int): index of this tower, only used in training.
+            total (int): total number of towers to be built.
+        """
+        super(TrainTowerContext, self).__init__(ns_name, vs_name)
+
+        self.index = int(index)
+        self.total = int(total)
+        if self.index > 0:
+            assert self.total > self.index, "(index, total) = ({}, {})".format(self.index, self.total)
+
+        vs = tf.get_variable_scope()
+        assert vs.name == '', "Cannot nest TrainTowerContext with an existing variable scope!"
+        if self.has_own_variables:
+            assert not vs.reuse, \
+                "Cannot create tower {} under reuse=True!".format(ns_name)
+
+    @property
+    def is_main_training_tower(self):
+        return self.index == 0
+
+    @property
+    def has_own_variables(self):
+        return self.index == 0 or len(self._vs_name) > 0
+
+    def _keys_to_freeze(self):
+        if self.index == 0:
+            return []
+        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]
+
+
+class PredictTowerContext(BaseTowerContext):
+
+    is_training = False
+
+    def __init__(self, ns_name, vs_name=''):
+        super(PredictTowerContext, self).__init__(ns_name, vs_name)
+        self._initial_vs_reuse = tf.get_variable_scope().reuse
+
+    @property
+    def has_own_variables(self):
+        return not self._initial_vs_reuse
+
+    @property
+    def is_main_training_tower(self):
+        return False
+
+    def _keys_to_freeze(self):
+        # freeze UPDATE_OPS during inference because they should never be used
+        return [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY, tf.GraphKeys.UPDATE_OPS]
+
+
 def get_current_tower_context():
     return _CurrentTowerContext
 
 
+def TowerContext(tower_name, is_training, vs_name=''):
+    """
+    User-facing API to build a tower manually.
+
+    Returns:
+        A context within which the tower function should be called.
+
+    Examples:
+
+    .. code-block:: python
+
+        with TowerContext('', is_training=True):
+            # call a tensorpack layer or a tower function
+    """
+    if is_training:
+        return TrainTowerContext(tower_name, vs_name=vs_name)
+    else:
+        return PredictTowerContext(tower_name, vs_name=vs_name)
+
+
 class TowerFuncWrapper(object):
     """
     A wrapper around a tower function (function which builds one tower, i.e. one replicate of the model).
...
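
With the `TowerContext` factory above, old `TowerContext(name, is_training)` call sites keep working and simply dispatch to the new classes. A quick sanity check (sketch, TF 1.x):

    from tensorpack.tfutils.tower import (
        TowerContext, TrainTowerContext, PredictTowerContext)

    assert isinstance(TowerContext('', is_training=True), TrainTowerContext)
    assert isinstance(TowerContext('pred', is_training=False), PredictTowerContext)
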
@@ -13,7 +13,7 @@ from ..utils.argtools import map_arg
 from ..utils.develop import HIDE_DOC
 from ..tfutils import get_global_step_var
 from ..tfutils.distributed import get_distributed_session_creator
-from ..tfutils.tower import TowerContext
+from ..tfutils.tower import TrainTowerContext
 from ..input_source import QueueInput
 
 from ..graph_builder.training import (
@@ -49,7 +49,7 @@ class SimpleTrainer(SingleCostTrainer):
     """
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
         logger.info("Building graph for a single training tower ...")
-        with TowerContext('', is_training=True):
+        with TrainTowerContext(''):
             grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
             opt = get_opt_fn()
             self.train_op = opt.apply_gradients(grads, name='min_op')
@@ -359,7 +359,7 @@ class HorovodTrainer(SingleCostTrainer):
         return averaged_gradients
 
     def _setup_graph(self, input, get_cost_fn, get_opt_fn):
-        with TowerContext('', is_training=True):
+        with TrainTowerContext(''):
            grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
            grads = self.allreduce(grads)
...