Commit cbb26847 authored by Yuxin Wu

comments for common trainers

parent 92ee69dc
@@ -49,7 +49,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
     def __init__(self, config, server):
         """
         Args:
-            config (TrainConfig): the train config.
+            config (TrainConfig): Must contain 'model' and 'data'.
             server (tf.train.Server): the server object with ps and workers
         """
         assert config.data is not None and config.model is not None
...
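A minimal usage sketch of DistributedTrainerReplicated, assuming the config-based trainer API of this tensorpack version; the cluster addresses, MyModel, and my_input_source are hypothetical placeholders:

    import tensorflow as tf
    from tensorpack import TrainConfig, DistributedTrainerReplicated

    # Hypothetical two-node cluster: one parameter server, two workers.
    cluster = tf.train.ClusterSpec({
        'ps': ['host0:2222'],
        'worker': ['host1:2222', 'host2:2222'],
    })
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    config = TrainConfig(
        model=MyModel(),       # hypothetical ModelDesc subclass
        data=my_input_source,  # hypothetical InputSource; 'data' is required here
    )
    DistributedTrainerReplicated(config, server).train()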
@@ -56,7 +56,7 @@ def QueueInputTrainer(config, input_queue=None):
     It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
     Args:
-        config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+        config (TrainConfig): Must contain 'model' and 'dataflow'.
         input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
     """
     assert (config.data is not None or config.dataflow is not None) and config.model is not None
...
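A hedged sketch of the equivalence this docstring states, assuming top-level imports of this version; MyModel and df are hypothetical stand-ins for a ModelDesc and a DataFlow:

    from tensorpack import TrainConfig, QueueInput, QueueInputTrainer, SimpleTrainer

    config = TrainConfig(model=MyModel(), dataflow=df)
    QueueInputTrainer(config).train()

    # Per the docstring, the above behaves like SimpleTrainer
    # with a queue-based InputSource:
    config = TrainConfig(model=MyModel(), data=QueueInput(df))
    SimpleTrainer(config).train()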
@@ -150,15 +150,15 @@ class LeastLoadedDeviceSetter(object):
 class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
     """
-    A data-parallel Multi-GPU trainer which synchronoizes the gradients computed
-    from each tower, averages them and update to variables stored across all
-    GPUs or on CPU.
+    A data-parallel multi-GPU trainer. It builds one tower on each GPU with
+    shared variable scope. It synchronizes the gradients computed
+    from each tower, averages them, and applies them to the shared variables.
     """
     def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
         """
         Args:
-            config (TrainConfig):
+            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
             ps_device: either 'gpu' or 'cpu', where variables are stored.
             gpu_prefetch (bool): whether to prefetch the data to each GPU. Usually improves performance.
         """
@@ -199,6 +199,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
         Returns:
             tf.Operation: the training op
+            [Callback]: the callbacks to be added
         """
         input.setup(model.get_inputs_desc())
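A hedged usage sketch of the parameter-server trainer described above; MyModel and df are hypothetical, and the nr_tower keyword is an assumption about TrainConfig of this version:

    from tensorpack import TrainConfig, SyncMultiGPUTrainerParameterServer

    config = TrainConfig(
        model=MyModel(),  # hypothetical ModelDesc
        dataflow=df,      # hypothetical DataFlow; an InputSource via 'data' also works
        nr_tower=2,       # assumption: one tower per visible GPU
    )
    # Keep variables on CPU; tower gradients are averaged, then applied to them.
    SyncMultiGPUTrainerParameterServer(config, ps_device='cpu').train()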
@@ -244,8 +245,9 @@ def SyncMultiGPUTrainer(config):
 class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
     """
-    Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
-    whole model. Each gradient update is broadcast and synced.
+    Data-parallel multi-GPU trainer where each GPU contains a replicate of the whole model.
+    It will build one tower on each GPU under its own variable scope.
+    Each gradient update is averaged across all GPUs through NCCL.
     """
     def __init__(self, config, gpu_prefetch=True):
         """
@@ -289,6 +291,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         Returns:
             tf.Operation: the training op
+            [Callback]: the callbacks to be added
         """
         input.setup(model.get_inputs_desc())
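The added return line documents a two-element contract. A hedged sketch of what a caller receives, assuming model and input are an already-built ModelDesc and InputSource:

    # model and input are assumed to exist; this only illustrates the contract.
    train_op, callbacks = SyncMultiGPUTrainerReplicated.setup_graph(model, input)
    # train_op: one run performs a synchronized step; per-GPU gradients are
    #           averaged through NCCL before being applied.
    # callbacks: extra Callback objects the trainer should register, e.g. an op
    #            keeping the per-tower variable copies consistent (assumption).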
@@ -346,14 +349,15 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
 class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
     """
-    A multi-tower multi-GPU trainer where each tower independently
-    asynchronously updates the model without averaging the gradient.
+    A data-parallel multi-GPU trainer. It builds one tower on each GPU with shared variable scope.
+    Every tower computes the gradients and independently applies them to the
+    variables, without synchronizing and averaging across towers.
     """
     def __init__(self, config, scale_gradient=True):
         """
         Args:
-            config (TrainConfig):
+            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
             scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
         """
         apply_prefetch_policy(config)
@@ -372,6 +376,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
         Returns:
             tf.Operation: the training op
+            [Callback]: the callbacks to be added
         """
         input.setup(model.get_inputs_desc())
...
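A hedged sketch of the asynchronous trainer, under the same assumptions as the sketches above (MyModel, df, and the nr_tower keyword are hypothetical):

    from tensorpack import TrainConfig, AsyncMultiGPUTrainer

    config = TrainConfig(model=MyModel(), dataflow=df, nr_tower=2)
    # Each tower applies its own gradients as soon as they are computed;
    # scale_gradient=True scales every gradient by 1.0/nr_gpu so the effective
    # step size stays comparable to the synchronous trainers.
    AsyncMultiGPUTrainer(config, scale_gradient=True).train()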
@@ -14,14 +14,17 @@ __all__ = ['SimpleTrainer']
 class SimpleTrainer(Trainer):
     """ A naive single-tower single-cost demo trainer.
-    Support both InputSource and DataFlow.
-    When DataFlow is given instead of InputSource, the InputSource to be used will be ``FeedInput(df)``.
+    It simply builds one tower and minimizes `model.cost`.
+    It supports both InputSource and DataFlow.
+    When DataFlow is given instead of InputSource, the InputSource to be
+    used will be ``FeedInput(df)`` (no prefetch).
     """
     def __init__(self, config):
         """
         Args:
-            config (TrainConfig): the training config.
+            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
         """
         assert len(config.tower) == 1, \
             "Got nr_tower={}, but doesn't support multigpu!" \
@@ -39,7 +42,7 @@ class SimpleTrainer(Trainer):
     @staticmethod
     def setup_graph(model, input):
         """
-        Setup graph for simple trainer.
+        Setup graph for SimpleTrainer. It simply builds one tower and optimizes `model.cost`.
         Args:
             model (ModelDesc):
@@ -47,6 +50,7 @@ class SimpleTrainer(Trainer):
         Returns:
             tf.Operation: the training op
+            [Callback]: the callbacks to be added
         """
         input.setup(model.get_inputs_desc())
...
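A minimal end-to-end sketch of SimpleTrainer with a DataFlow; MyModel is a hypothetical single-cost ModelDesc, and FakeData (a synthetic dataflow from tensorpack.dataflow) stands in for real data:

    from tensorpack import TrainConfig, SimpleTrainer
    from tensorpack.dataflow import FakeData

    df = FakeData([[64, 28, 28, 1], [64]], size=1000)  # random image/label batches
    config = TrainConfig(
        model=MyModel(),  # hypothetical single-cost ModelDesc
        dataflow=df,      # wrapped into FeedInput(df) internally, so no prefetch
        max_epoch=1,
    )
    SimpleTrainer(config).train()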