Commit cbb26847 authored by Yuxin Wu

comments for common trainers

parent 92ee69dc
@@ -49,7 +49,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
    def __init__(self, config, server):
        """
        Args:
-            config (TrainConfig): the train config.
+            config(TrainConfig): Must contain 'model' and 'data'.
            server (tf.train.Server): the server object with ps and workers
        """
        assert config.data is not None and config.model is not None
...
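For orientation, here is a rough usage sketch (not part of this commit) of the constructor above; `MyModel`, `df`, and the cluster addresses are placeholders, and the import path is an assumption about tensorpack's layout at the time:

```python
import tensorflow as tf
from tensorpack import TrainConfig, QueueInput
from tensorpack.train.distributed import DistributedTrainerReplicated  # path assumed

# One parameter server and two workers; addresses are illustrative only.
cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

config = TrainConfig(
    model=MyModel(),       # hypothetical ModelDesc subclass; 'model' must be set
    data=QueueInput(df),   # df: hypothetical DataFlow; 'data' must be set per the docstring
)
DistributedTrainerReplicated(config, server).train()
```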
@@ -56,7 +56,7 @@ def QueueInputTrainer(config, input_queue=None):
    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
    Args:
-        config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+        config (TrainConfig): Must contain 'model' and 'dataflow'.
        input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
    """
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
...
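A minimal sketch of the requirement stated in the new docstring ('model' and 'dataflow'); `MyModel` and `df` are placeholder names and the import path is assumed:

```python
from tensorpack import TrainConfig, QueueInputTrainer  # import path assumed

config = TrainConfig(
    dataflow=df,        # df: hypothetical DataFlow yielding training data points
    model=MyModel(),    # hypothetical ModelDesc subclass
)
# Per the docstring above, this is equivalent to SimpleTrainer(config)
# with config.data = QueueInput(dataflow).
QueueInputTrainer(config).train()
```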
@@ -150,15 +150,15 @@ class LeastLoadedDeviceSetter(object):
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
    """
-    A data-parallel Multi-GPU trainer which synchronoizes the gradients computed
-    from each tower, averages them and update to variables stored across all
-    GPUs or on CPU.
+    A data-parallel multi-GPU trainer. It builds one tower on each GPU with
+    shared variable scope. It synchronizes the gradients computed
+    from each tower, averages them and applies them to the shared variables.
    """
    def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
        """
        Args:
-            config(TrainConfig):
+            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            ps_device: either 'gpu' or 'cpu', where variables are stored.
            gpu_prefetch(bool): whether to prefetch the data to each GPU. Usually improve performance.
        """
@@ -199,6 +199,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
        Returns:
            tf.Operation: the training op
            [Callback]: the callbacks to be added
        """
        input.setup(model.get_inputs_desc())
@@ -244,8 +245,9 @@ def SyncMultiGPUTrainer(config):
class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
    """
-    Data-parallel Multi-GPU trainer where each GPU contains a replicate of the
-    whole model. Each gradient update is broadcast and synced.
+    Data-parallel multi-GPU trainer where each GPU contains a replicate of the whole model.
+    It will build one tower on each GPU under its own variable scope.
+    Each gradient update is averaged across all GPUs through NCCL.
    """
    def __init__(self, config, gpu_prefetch=True):
        """
@@ -289,6 +291,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
        Returns:
            tf.Operation: the training op
            [Callback]: the callbacks to be added
        """
        input.setup(model.get_inputs_desc())
@@ -346,14 +349,15 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
    """
-    A multi-tower multi-GPU trainer where each tower independently
-    asynchronously updates the model without averaging the gradient.
+    A data-parallel multi-GPU trainer. It builds one tower on each GPU with shared variable scope.
+    Every tower computes the gradients and independently applies them to the
+    variables, without synchronizing and averaging across towers.
    """
    def __init__(self, config, scale_gradient=True):
        """
        Args:
-            config(TrainConfig):
+            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
        """
        apply_prefetch_policy(config)
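The `scale_gradient` option documented above multiplies every gradient by ``1.0/nr_gpu``. A tiny, self-contained illustration of that scaling step (placeholder names, not the trainer's actual code):

```python
def scale_grads(grads_and_vars, nr_gpu):
    # grads_and_vars: output of optimizer.compute_gradients() for one tower
    return [(g * (1.0 / nr_gpu), v) for g, v in grads_and_vars if g is not None]
```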
@@ -372,6 +376,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainerBase):
        Returns:
            tf.Operation: the training op
            [Callback]: the callbacks to be added
        """
        input.setup(model.get_inputs_desc())
...
@@ -14,14 +14,17 @@ __all__ = ['SimpleTrainer']
class SimpleTrainer(Trainer):
    """ A naive single-tower single-cost demo trainer.
-    Support both InputSource and DataFlow.
-    When DataFlow is given instead of InputSource, the InputSource to be used will be ``FeedInput(df)``.
+    It simply builds one tower and minimizes `model.cost`.
+    It supports both InputSource and DataFlow.
+    When DataFlow is given instead of InputSource, the InputSource to be
+    used will be ``FeedInput(df)`` (no prefetch).
    """
    def __init__(self, config):
        """
        Args:
-            config (TrainConfig): the training config.
+            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
        """
        assert len(config.tower) == 1, \
            "Got nr_tower={}, but doesn't support multigpu!" \
@@ -39,7 +42,7 @@ class SimpleTrainer(Trainer):
    @staticmethod
    def setup_graph(model, input):
        """
-        Setup graph for simple trainer.
+        Setup graph for SimpleTrainer. It simply builds one tower and optimizes `model.cost`.
        Args:
            model (ModelDesc):
@@ -47,6 +50,7 @@ class SimpleTrainer(Trainer):
        Returns:
            tf.Operation: the training op
            [Callback]: the callbacks to be added
        """
        input.setup(model.get_inputs_desc())
...
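Finally, a sketch of the `setup_graph` contract documented above: the trainer calls it internally, but the documented return values are a training op and a list of callbacks. Names here are placeholders:

```python
# Illustration of the documented return contract only.
model = MyModel()        # hypothetical ModelDesc
inp = FeedInput(df)      # hypothetical InputSource over a DataFlow df
train_op, callbacks = SimpleTrainer.setup_graph(model, inp)
# train_op: a tf.Operation running one optimization step;
# callbacks: tensorpack Callback objects the trainer should register.
```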