Commit 92f159b6 authored by Yuxin Wu

update docs on trainer

parent e78e2e1e
PLEASE finish reading to show some respect to the authors.
An issue has to be one of the following:
- Unexpected Problems / Potential Bugs
- Feature Requests
- Questions on Using/Understanding Tensorpack

## For any unexpected problems, __PLEASE ALWAYS INCLUDE__:
1. What you did:
@@ -28,7 +30,7 @@ About efficiency issues, PLEASE first read http://tensorpack.readthedocs.io/en/l
(See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack).
It does not have to be added to Tensorpack unless you have a good reason.
+ "Could you improve/implement an example/paper ?" + "Could you improve/implement an example/paper ?"
-- The answer is: we have no plans to do so. We don't take feature requests for
examples or implement a paper for you. If you don't know how to do it, you may ask a usage question.
## Usage Questions:

@@ -36,7 +38,7 @@ About efficiency issues, PLEASE first read http://tensorpack.readthedocs.io/en/l
+ Read the [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#user-tutorials) first.
+ We answer "HOW to do X with Tensorpack" for a well-defined X.
  We also answer "HOW/WHY Tensorpack does X" for some X that Tensorpack or its examples are doing.
  We don't answer general machine learning questions, such as "why my training doesn't converge", "what networks to use" or "I don't understand the paper".

You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
@@ -13,43 +13,58 @@ But some basic knowledge of how they work is useful:

### Tower Trainer
[TowerTrainer](../modules/train.html#tensorpack.train.TowerTrainer)
is a trainer that uses a "tower function" to build models.
All existing trainers in tensorpack are subclasses of ``TowerTrainer``,
because this concept covers most types of neural-network training tasks.

#### What is a Tower Function
Following the terminology in TensorFlow,
a __tower function__ is a callable that takes input tensors and adds __one replicate__ of the model to the graph.
The concept of tower is used mainly to support:
1. Data-parallel multi-GPU training, where a replicate is built on each GPU.
2. Graph construction for inference, where a replicate is built under inference mode.

A user needs to provide a tower function to use `TowerTrainer`.
In particular, when working with the `ModelDesc` interface, the `build_graph`
method will be part of the tower function.
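
For illustration, here is a minimal sketch of that interface (the shapes, layer sizes and names are made up for this example, and it assumes the `inputs()` / `build_graph()` / `optimizer()` methods of `ModelDesc` described in the tutorials):

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # Placeholders describing one batch of input tensors.
        return [tf.placeholder(tf.float32, (None, 28, 28, 1), 'image'),
                tf.placeholder(tf.int32, (None,), 'label')]

    def build_graph(self, image, label):
        # This method becomes (part of) the tower function: it takes the
        # input tensors and adds one replicate of the model to the graph.
        logits = tf.layers.dense(tf.layers.flatten(image), 10)
        return tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits),
            name='total_cost')

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```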

#### Rules of Tower Function

The tower function needs to follow some rules:

1. __It may get called multiple times__ for data-parallel training or inference. As a result:
   * You'll need to be careful when modifying global states, e.g.
     adding ops to collections, setting attributes of a model instance.
   * To use a tensorflow-hub module, you need to initialize the
     module outside the tower function, and call the module inside the tower function.
2. It must __respect variable collections__:
   * (Required) Only put variables __trainable by gradient descent__ into `TRAINABLE_VARIABLES`.
   * (Recommended) Put non-trainable variables that need to be used in inference into `MODEL_VARIABLES`.
3. It must __respect variable scopes__:
   * The name of any trainable variable created in the function must be like "variable_scope_name/custom/scopes/name".
     Don't depend on name_scope's name. Don't use variable_scope's name twice.
   * The creation of any trainable variables must __respect reuse__ variable scope.
     To respect variable reuse (i.e. sharing), use `tf.get_variable` instead of `tf.Variable` in the function (see the sketch below).
     On the other hand, for a non-trainable variable, it may be desirable to not reuse it between towers.
     In this case, `tf.Variable` can be used to ensure creation of new variables in each tower even when `reuse=True`.
4. It cannot create scopes or variables containing the name 'tower', as it is
   reserved for special use.

These conventions are easy to follow, and most layer wrappers (e.g.,
tf.layers/slim/tensorlayer) do follow them. Note that certain Keras layers do not
follow these conventions and will need some workarounds if used within tensorpack.
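
As a rough sketch of the reuse rule above (a hypothetical layer, not tensorpack code): trainable weights go through `tf.get_variable` so that a second call under `reuse=True` shares them, while per-tower, non-trainable state can use `tf.Variable`:

```python
import tensorflow as tf

def my_layer(x, out_dim=64):
    # Trainable weight: tf.get_variable respects the surrounding variable scope,
    # so calling the tower function again with reuse=True shares this weight.
    w = tf.get_variable('w', shape=[int(x.shape[-1]), out_dim],
                        initializer=tf.variance_scaling_initializer())
    # Non-trainable, per-tower state: tf.Variable always creates a new variable,
    # even when the enclosing variable scope has reuse=True.
    step = tf.Variable(0, trainable=False, name='local_step')
    return tf.matmul(x, w), step
```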

#### What You Can Do Inside a Tower Function
1. Call any symbolic functions as long as they follow the above rules.
2. The function will be called under a
[TowerContext](../modules/tfutils.html#tensorpack.tfutils.tower.BaseTowerContext),
which can be accessed by [get_current_tower_context()](../modules/tfutils.html#tensorpack.tfutils.tower.get_current_tower_context).
The context contains information about training/inference mode, scope name, etc.
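
For example, a tower function may branch on the mode reported by the context; a minimal sketch (the dropout rate and layers are arbitrary):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import get_current_tower_context

def tower_func(image):
    ctx = get_current_tower_context()
    # ctx.is_training tells whether this replicate is built for training or inference.
    keep_prob = 0.5 if ctx.is_training else 1.0
    x = tf.layers.flatten(image)
    x = tf.nn.dropout(x, keep_prob)
    return tf.layers.dense(x, 10, name='logits')
```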

### MultiGPU Trainers
@@ -62,17 +77,17 @@ It takes only one line of code change to use them, e.g. `trainer=SyncMultiGPUTra

Note some __common problems__ when using these trainers:
1. In each iteration, instead of taking one tensor for all GPUs and splitting it,
   all GPUs take tensors from the `InputSource`.
   So the total batch size would become ``(batch size of InputSource) * #GPU`` (see the sketch after this list).
   Splitting a tensor for data-parallel training makes no sense at all. First,
   it wastes time because typically data is concatenated into batches by the user.
   Second, this puts unnecessary shape constraints on the data.
   By letting each GPU train on its own input tensors, they can train on inputs of different shapes simultaneously.
2. The tower function (your model code) will get called multiple times on each GPU.
   You must follow the above-mentioned rules of the tower function.
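
A rough sketch of what this looks like with the `launch_train_with_config` workflow (here `MyModel` and `df` stand for your own `ModelDesc` and DataFlow, and `df` is assumed to produce batches of 32):

```python
from tensorpack import TrainConfig, SyncMultiGPUTrainerReplicated, launch_train_with_config

config = TrainConfig(
    model=MyModel(),   # your ModelDesc; its build_graph is the tower function
    dataflow=df,       # yields batches of 32 per iteration
    max_epoch=100,
)
# With 4 GPUs, each replicate pulls its own batch of 32 from the InputSource,
# so the effective total batch size per iteration is 32 * 4 = 128.
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(4))
```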

### Distributed Trainers
@@ -83,4 +98,4 @@ documentation of [HorovodTrainer](../modules/train.html#tensorpack.train.Horovod
Tensorpack has implemented some other distributed trainers using TF's native API,
but TensorFlow is not actively supporting its distributed training features, and
its native distributed performance isn't very good even today.
Therefore those trainers are not actively maintained and are __not recommended for use__.
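
For completeness, a hedged sketch of the recommended alternative, `HorovodTrainer` (launched through MPI; the script name, flags and `MyModel`/`df` are illustrative only):

```python
# train.py -- launch with something like:  mpirun -np 4 python train.py
from tensorpack import TrainConfig, launch_train_with_config
from tensorpack.train import HorovodTrainer

config = TrainConfig(model=MyModel(), dataflow=df, max_epoch=100)
# One MPI process drives one GPU; Horovod performs the gradient averaging.
launch_train_with_config(config, HorovodTrainer())
```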
@@ -14,7 +14,7 @@ from ..utils.develop import HIDE_DOC
from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name

__all__ = ['get_current_tower_context', 'BaseTowerContext', 'TowerContext', 'TowerFuncWrapper',
           'TowerTensorHandle', 'TowerTensorHandles']

_CurrentTowerContext = None
@@ -23,9 +23,10 @@ _CurrentTowerContext = None
@six.add_metaclass(ABCMeta)
class BaseTowerContext(object):
    """ A context where the current model is built in.

    You need to use :func:`TowerContext` to create a :class:`BaseTowerContext`.
    """

    @HIDE_DOC
    def __init__(self, ns_name, vs_name=''):
        """
        Args:
@@ -40,31 +41,49 @@ class BaseTowerContext(object):
    @abstractproperty
    def is_main_training_tower(self):
        """
        Whether this tower is the main (i.e., the first) training tower.
        """
        pass

    @abstractproperty
    def has_own_variables(self):
        """
        Whether this tower is supposed to have its own trainable variables.
        """
        pass

    @property
    def name(self):
        """
        Returns:
            str - The name scope of the tower.
        """
        return self._name

    @property
    def vs_name(self):
        """
        Returns:
            str - The variable scope of the tower.
        """
        return self._vs_name

    @property
    def ns_name(self):
        """
        Returns:
            str - The name scope of the tower.
        """
        return self._name

    def get_collection_in_tower(self, key):
        """
        From a collection, get items that are __added__ to the collection in this tower.
        These items may or may not start with the same prefix as the tower.

        Note that it works by tracking the collection at the beginning and end of
        the tower function.
        Therefore it does not guarantee that the items are __created__ in this tower.
        """
        return self._collection_guard.get_collection_in_tower(key)

@@ -194,15 +213,21 @@ class PredictTowerContext(BaseTowerContext):
def get_current_tower_context():
    """
    When called inside a TowerContext, return the TowerContext.

    Returns:
        a :class:`BaseTowerContext` instance.
    """
    assert _CurrentTowerContext is not None, "The function is supposed to be called under a TowerContext!"
    return _CurrentTowerContext

def TowerContext(tower_name, is_training, vs_name=''):
    """
    The context for a tower function, containing metadata about the current tower.
    Tensorpack trainers use :class:`TowerContext` to manage tower functions.
    Many tensorpack layers have to be called under a :class:`TowerContext`.

    Example:

@@ -219,7 +244,8 @@ def TowerContext(tower_name, is_training, vs_name=''):

class TowerFuncWrapper(object):
    """
    A wrapper around a tower function (see
    [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)).
    It keeps track of the name scope, variable scope and input/output tensors
    each time the function is called.

@@ -231,7 +257,8 @@ class TowerFuncWrapper(object):
        Args:
            tower_func: a function which builds one tower in the graph.
                It takes several input tensors and could return anything.
            inputs_desc ([InputDesc]): list of :class:`InputDesc`.
                They are used to figure out the names for the input tensors.
        """
        assert callable(tower_fn), tower_fn
        inputs_desc_names = [k.name for k in inputs_desc]

@@ -344,7 +371,9 @@ class TowerTensorHandle(object):
    def get_tensor(self, name):
        """
        Get a tensor in this tower. The name can be:

        1. The name of the tensor without any tower prefix.
        2. The name of an :class:`InputDesc`, if it is used when building the tower.
        """
        name = get_op_tensor_name(name)[1]

@@ -368,14 +397,24 @@ class TowerTensorHandle(object):
        return ret

    def get_tensors(self, names):
        """
        Like :meth:`get_tensor`, but takes a list and returns a list.
        """
        return [self.get_tensor(name) for name in names]

    def __getitem__(self, name):
        """
        The same as :meth:`get_tensor`.
        """
        return self.get_tensor(name)

    def get_variable(self, name):
        """
        Get a variable used in this tower.
        The name should not contain the variable scope prefix of the tower.

        When the tower has the same variable scope and name scope, this is equivalent to
        :meth:`get_tensor`.
        """
        name = get_op_tensor_name(name)[1]
        if len(self.vs_name):

@@ -384,15 +423,24 @@ class TowerTensorHandle(object):
            name_with_vs = name
        return get_op_or_tensor_by_name(name_with_vs)

    def get_variables(self, names):
        """
        Like :meth:`get_variable`, but takes a list and returns a list.
        """
        return [self.get_variable(name) for name in names]

    def get_collection(self, key=None, name=None):
        """
        See :meth:`BaseTowerContext.get_collection_in_tower`.

        Args:
            key (str): the key of the collection
            name: deprecated
        """
        if name is not None:
            logger.warn("TowerTensorHandle.get_collection(name=..) was renamed to (key=..) !")
            key = name
        return self._ctx.get_collection_in_tower(key)

    @property
    def input(self):

@@ -46,7 +46,8 @@ class TowerTrainer(Trainer):
    def tower_func(self):
        """
        A :class:`TowerFuncWrapper` instance.
        See [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)
        for more information.
        """
        return self._tower_func

@@ -70,6 +71,14 @@ class TowerTrainer(Trainer):
        access the tower handles by either indices or names.
        It is accessible only after the graph is set up.

        With :meth:`towers`, you can then access many attributes of each tower:

        Example:

        .. code-block:: python

            # Access the conv1/output tensor in the first training tower
            trainer.towers.training()[0].get_tensor('conv1/output')
        """
        return self.tower_func.towers

@@ -92,9 +101,9 @@ class TowerTrainer(Trainer):
            # in the graph:
            interesting_tensor = tf.identity(x, name='fun')
            # in _setup_graph callback method:
            self._predictor = self.trainer.get_predictor(['input1', 'input2'], ['fun'])
            # After session is initialized (see Tutorials - Write a Callback), can use it by:
            outputs = self._predictor(input1, input2)

        The CycleGAN example and DQN example have more concrete use of this method.
        """