Commit 92f159b6 authored by Yuxin Wu

update docs on trainer

parent e78e2e1e
PLEASE finish reading to show some respect to the authors.
An issue has to be one of the following:
- Unexpected Problems / Potential Bugs
- Feature Requests
- Questions on Using/Understanding Tensorpack
## For any unexpected problems, __PLEASE ALWAYS INCLUDE__:
1. What you did:
@@ -28,7 +30,7 @@ About efficiency issues, PLEASE first read http://tensorpack.readthedocs.io/en/l
(See http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#extend-tensorpack).
It does not have to be added to Tensorpack unless you have a good reason.
+ "Could you improve/implement an example/paper ?"
-- the answer is: we have no plans to do so. We don't take feature requests for
-- The answer is: we have no plans to do so. We don't take feature requests for
examples or implement a paper for you. If you don't know how to do it, you may ask a usage question.
## Usage Questions:
@@ -36,7 +38,7 @@ About efficiency issues, PLEASE first read http://tensorpack.readthedocs.io/en/l
+ Read the [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html#user-tutorials) first.
+ We answer "HOW to do X with Tensorpack" for a well-defined X.
We also answer "HOW/WHY Tensorpack does X" for some X that Tensorpack or its examples are doing.
We don't answer general machine learning questions, such as "why my training doesn't converge", "what networks to use" or "I don't understand the paper".
You can also use gitter (https://gitter.im/tensorpack/users) for more casual discussions.
@@ -13,43 +13,58 @@ But some basic knowledge of how they work is useful:
### Tower Trainer
[TowerTrainer](../modules/train.html#tensorpack.train.TowerTrainer)
is a trainer that uses "tower function" to build models.
#### What is Tower Function
Following the terminology in TensorFlow,
a __tower function__ is a callable that takes input tensors and adds __one replicate__ of the model to the graph.
Most types of neural-network training could be described with this concept.
The concept of tower is used mainly to support:
1. Data-parallel multi-GPU training, where a replicate is built on each GPU.
2. Graph construction for inference, where a replicate is built under inference mode.
A user needs to provide a tower function to use `TowerTrainer`.
In particular, when working with the `ModelDesc` interface, the `build_graph`
method will be part of the tower function.
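As a rough illustration (not taken from this tutorial), a `ModelDesc`-based tower function may look like the sketch below. It assumes the `inputs()` / `build_graph()` / `optimizer()` form of the `ModelDesc` interface; the model, tensor names and sizes are made up.

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    def inputs(self):
        # Declare the input tensors that the tower function will receive.
        return [tf.placeholder(tf.float32, (None, 28, 28), 'image'),
                tf.placeholder(tf.int32, (None,), 'label')]

    def build_graph(self, image, label):
        # This method becomes part of the tower function: it takes input
        # tensors and adds one replicate of the model to the graph.
        image = tf.reshape(image, [-1, 28 * 28])
        logits = tf.layers.dense(image, 10)
        cost = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
        return tf.identity(cost, name='total_cost')

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)
```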
#### Rules of Tower Function
The tower function needs to follow some rules:
1. __It may get called multiple times__ for data-parallel training or inference. As a result:
* You'll need to be careful when modifying global states, e.g.
adding ops to collections, setting attributes of a model instance.
* To use a tensorflow-hub module, you need to initialize the
module outside the tower function, and call the module inside the tower function.
2. It must __respect variable collections__:
* (Required) Only put variables __trainable by gradient descent__ into `TRAINABLE_VARIABLES`.
* (Recommended) Put non-trainable variables that need to be used in inference into `MODEL_VARIABLES`.
3. It must __respect variable scopes__:
* The name of any trainable variables created in the function must be like "variable_scope_name/custom/scopes/name".
Don't depend on name_scope's name. Don't use variable_scope's name twice.
* The creation of any trainable variables must __respect reuse__ variable scope.
To respect variable reuse (i.e. sharing), use `tf.get_variable` instead of `tf.Variable` in the function.
On the other hand, for a non-trainable variable, it may be desirable to not reuse it between towers.
In this case, `tf.Variable` can be used to ensure creation of new variables in each tower even when `reuse=True`.
4. It cannot create scopes or variables containing the name 'tower', as it is
reserved for special use.
These conventions are easy to follow, and most layer wrappers (e.g.,
tf.layers/slim/tensorlayer) do follow them. Note that certain Keras layers do not
follow these conventions and will need some workarounds if used within tensorpack.
It is possible to write trainers that are not tower-based, but all existing trainers in
tensorpack are subclasses of [TowerTrainer](../modules/train.html#tensorpack.train.TowerTrainer).
#### What You Can Do Inside Tower Function
1. Call any symbolic functions as long as they follow the above rules (see the sketch after this list).
2. The function will be called under a
[TowerContext](../modules/tfutils.html#tensorpack.tfutils.tower.BaseTowerContext),
which can be accessed by [get_current_tower_context()](../modules/tfutils.html#tensorpack.tfutils.tower.get_current_tower_context).
The context contains information about training/inference mode, scope name, etc.
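Below is a minimal sketch (not from the tensorpack docs) of a hand-written tower function that follows the rules above; the function name `my_tower_func`, the scope `linear`, and variables `scale` / `local_step` are made up for illustration.

```python
import tensorflow as tf
from tensorpack.tfutils.tower import get_current_tower_context

def my_tower_func(image):
    """Builds one replicate of a tiny linear model; `image` is a [None, 784] tensor."""
    ctx = get_current_tower_context()  # always available inside a tower function

    # Rule 3: create trainable variables with tf.get_variable under a variable
    # scope, so that `reuse` (weight sharing across towers) is respected.
    with tf.variable_scope('linear'):
        w = tf.get_variable('W', [784, 10])
        b = tf.get_variable('b', [10])
    logits = tf.nn.xw_plus_b(image, w, b, name='logits')

    # Rule 2: a non-trainable variable needed at inference time goes into MODEL_VARIABLES.
    scale = tf.get_variable(
        'scale', [], trainable=False, initializer=tf.ones_initializer(),
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.MODEL_VARIABLES])

    # For a purely per-tower counter, tf.Variable creates a fresh variable in
    # every tower even when reuse=True.
    tf.Variable(0, trainable=False, name='local_step')

    if ctx.is_training:
        # Rule 1: the function may run several times; avoid surprising global
        # side effects and gate training-only ops on the context.
        tf.summary.histogram('logits-histogram', logits)
    return logits * scale
```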
### MultiGPU Trainers
@@ -62,17 +77,17 @@ It takes only one line of code change to use them, e.g. `trainer=SyncMultiGPUTra
Note some __common problems__ when using these trainers:
1. In each iteration, instead of taking one tensor for all GPUs and split,
all GPUs take tensors from the `InputSource`.
So the total batch size would become ``(batch size of InputSource) * #GPU``.
Splitting a tensor for data-parallel training makes no sense at all. First,
it wastes time: the data typically gets concatenated into batches by the user, only to be split again by the trainer.
Second, this puts unnecessary shape constraints on the data.
By letting each GPU train on its own input tensors, they can train on inputs of different shapes simultaneously.
2. The tower function (your model code) will get called multiple times on each GPU.
You must follow the above-mentioned rules of the tower function (see the sketch after this list).
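For illustration only, assuming an already-built `TrainConfig` named `config` whose dataflow yields batches of size 32, switching to data-parallel training is a one-line change:

```python
from tensorpack import SyncMultiGPUTrainerReplicated, launch_train_with_config

# Each of the 2 GPUs pulls its own batch of 32 from the InputSource,
# so the effective total batch size is 32 * 2 = 64.
launch_train_with_config(config, SyncMultiGPUTrainerReplicated(2))
```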
### Distributed Trainers
@@ -83,4 +98,4 @@ documentation of [HorovodTrainer](../modules/train.html#tensorpack.train.Horovod
Tensorpack has implemented some other distributed trainers using TF's native API,
but TensorFlow is not actively supporting its distributed training features, and
its native distributed performance isn't very good even today.
Therefore those trainers are not actively maintained and are __not recommended for use__.
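A hedged sketch of the recommended distributed path (Horovod); `config` is an assumed, already-built `TrainConfig`, and the launch command follows Horovod's generic MPI convention rather than anything specific to this commit:

```python
from tensorpack import HorovodTrainer, launch_train_with_config

launch_train_with_config(config, HorovodTrainer())
# Typically launched with MPI, e.g.:
#   mpirun -np 4 -H host1:2,host2:2 python train.py
```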
@@ -14,7 +14,7 @@ from ..utils.develop import HIDE_DOC
from .collection import CollectionGuard
from .common import get_op_or_tensor_by_name, get_op_tensor_name
__all__ = ['get_current_tower_context', 'BaseTowerContext', 'TowerContext', 'TowerFuncWrapper',
'TowerTensorHandle', 'TowerTensorHandles']
_CurrentTowerContext = None
@@ -23,9 +23,10 @@ _CurrentTowerContext = None
@six.add_metaclass(ABCMeta)
class BaseTowerContext(object):
""" A context where the current model is built in.
Since TF 1.8, TensorFlow has started to introduce the same concept.
You need to use :func:`TowerContext` to create a :class:`BaseTowerContext`.
"""
@HIDE_DOC
def __init__(self, ns_name, vs_name=''):
"""
Args:
@@ -40,31 +41,49 @@ class BaseTowerContext(object):
@abstractproperty
def is_main_training_tower(self):
"""
Whether this tower is the main (i.e., the first) training tower.
"""
pass
@abstractproperty
def has_own_variables(self):
"""
Whether this tower is supposed to have its own trainable variables.
"""
pass
@property
def name(self):
"""
Returns:
str - The name scope of the tower.
"""
return self._name
@property
def vs_name(self):
"""
Returns:
str - The variable scope of the tower.
"""
return self._vs_name
@property
def ns_name(self):
"""
Returns:
str - The name scope of the tower.
"""
return self._name
def get_collection_in_tower(self, key):
"""
From a collection, get items that are __added__ to the collection in this tower.
Note that it works by tracking the collection at the beginning and end of
the tower function.
Therefore it does not guarantee that the items are __created__ in this tower.
"""
return self._collection_guard.get_collection_in_tower(key)
@@ -194,15 +213,21 @@ class PredictTowerContext(BaseTowerContext):
def get_current_tower_context():
"""
When called inside a TowerContext, return the TowerContext.
Returns:
a :class:`BaseTowerContext` instance.
"""
assert _CurrentTowerContext is not None, "The function is supposed to be called under a TowerContext!"
return _CurrentTowerContext
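# Illustrative usage (not part of the original file): inside a tower function,
#   ctx = get_current_tower_context()
#   if ctx.is_training and ctx.is_main_training_tower:
#       ...  # e.g. add summaries or other once-per-step ops only in the main tower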
def TowerContext(tower_name, is_training, vs_name=''):
"""
The context for a tower function, containing metadata about the current tower.
Tensorpack trainers use :class:`TowerContext` to manage tower function.
Many tensorpack layers have to be called under a :class:`TowerContext`.
Example:
@@ -219,7 +244,8 @@ def TowerContext(tower_name, is_training, vs_name=''):
class TowerFuncWrapper(object):
"""
A wrapper around a tower function (see
[tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)).
It keeps track of the name scope, variable scope and input/output tensors
each time the function is called.
@@ -231,7 +257,8 @@ class TowerFuncWrapper(object):
Args:
tower_fn: a function which builds one tower in the graph.
It takes several input tensors and could return anything.
inputs_desc ([InputDesc]): list of :class:`InputDesc`.
They are used to figure out the names for the input tensors.
"""
assert callable(tower_fn), tower_fn
inputs_desc_names = [k.name for k in inputs_desc]
@@ -344,7 +371,9 @@ class TowerTensorHandle(object):
def get_tensor(self, name):
"""
Get a tensor in this tower. The name can be:
1. The name of the tensor without any tower prefix.
2. The name of an :class:`InputDesc`, if it is used when building the tower.
"""
name = get_op_tensor_name(name)[1]
@@ -368,14 +397,24 @@
return ret
def get_tensors(self, names):
"""
Like :meth:`get_tensor`, but takes a list and returns a list.
"""
return [self.get_tensor(name) for name in names]
def __getitem__(self, name):
"""
The same as :meth:`get_tensor`.
"""
return self.get_tensor(name)
def get_variable(self, name):
"""
Get a variable used in this tower.
The name should not contain the variable scope prefix of the tower.
When the tower has the same variable scope and name scope, this is equivalent to
:meth:`get_tensor`.
"""
name = get_op_tensor_name(name)[1]
if len(self.vs_name):
@@ -384,15 +423,24 @@
name_with_vs = name
return get_op_or_tensor_by_name(name_with_vs)
def get_variables(self, names):
"""
Like :meth:`get_variable`, but takes a list and returns a list.
"""
return [self.get_variable(name) for name in names]
def get_collection(self, key=None, name=None):
"""
See :meth:`BaseTowerContext.get_collection_in_tower`.
Args:
key (str): the key of the collection
name: deprecated
"""
if name is not None:
logger.warn("TowerTensorHandle.get_collection(name=..) was renamed to (key=..) !")
key = name
return self._ctx.get_collection_in_tower(key)
@property
def input(self):
@@ -46,7 +46,8 @@ class TowerTrainer(Trainer):
def tower_func(self):
"""
A :class:`TowerFuncWrapper` instance.
See [tutorial on tower function](http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer)
for more information.
"""
return self._tower_func
@@ -70,6 +71,14 @@
access the tower handles by either indices or names.
It is accessible only after the graph is set up.
With :meth:`towers`, you can then access many attributes of each tower:
Example:
.. code-block:: python
# Access the conv1/output tensor in the first training tower
trainer.towers.training()[0].get_tensor('conv1/output')
"""
return self.tower_func.towers
@@ -92,9 +101,9 @@
# in the graph:
interesting_tensor = tf.identity(x, name='fun')
# in _setup_graph callback method:
self._predictor = self.trainer.get_predictor(['input1', 'input2'], ['fun'])
# After session is initialized (see Tutorials - Write a Callback), can use it by:
outputs = self._predictor(input1, input2)
The CycleGAN example and DQN example have more concrete use of this method.
"""