Commit 5c25afcb authored by Yuxin Wu

update docs

parent 4692e325
@@ -62,8 +62,28 @@ class DistributedBuilderBase(GraphBuilder):
 class DistributedParameterServerBuilder(DataParallelBuilder, DistributedBuilderBase):
+    """
+    Distributed parameter server training.
+    A single copy of the parameters is scattered across the parameter servers (PS).
+    Gradients across GPUs are averaged within the worker and applied to the PS.
+    Each worker also caches the variables for reading.
+
+    It is the equivalent of ``--variable_update=parameter_server`` in
+    `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
+
+    Note:
+        1. Gradients are not averaged across workers, but applied to PS variables
+           directly (either with or without locking depending on the optimizer).
+    """
+
     def __init__(self, towers, server, caching_device):
+        """
+        Args:
+            towers (list[int]): list of GPU ids.
+            server (tf.train.Server): the server with ps and workers.
+                job_name must be 'worker'.
+            caching_device (str): either 'cpu' or 'gpu'.
+        """
         DataParallelBuilder.__init__(self, towers)
         DistributedBuilderBase.__init__(self, server)
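For context, a minimal sketch (not part of this commit) of how the ``server`` argument documented above could be constructed. The cluster addresses are placeholders; only ``DistributedParameterServerBuilder``, defined in this file, comes from the diff:

.. code-block:: python

    import tensorflow as tf

    # Hypothetical cluster layout; replace the host:port addresses with real ones.
    cluster = tf.train.ClusterSpec({
        'ps': ['host1:2222', 'host2:2222'],
        'worker': ['host1:2223', 'host2:2223'],
    })
    # job_name must be 'worker', as the docstring above requires.
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # towers are the GPU ids on this worker; variables read back from PS
    # are cached on the CPU.
    builder = DistributedParameterServerBuilder(
        towers=[0, 1], server=server, caching_device='cpu')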
@@ -120,9 +140,13 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
     `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
 
     Note:
-        Gradients are not averaged across workers, but applied to PS variables
-        directly (either with or without locking depending on the optimizer).
+        1. Gradients are not averaged across workers, but applied to PS variables
+           directly (either with or without locking depending on the optimizer).
+        2. Some details about collections: all variables created inside a tower
+           will become local variables,
+           and a clone will be made in global variables for all trainable/model variables.
 
     Example:
 
         .. code-block:: python
@@ -142,9 +166,9 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
             # Start training like this:
             (host1)$ train.py --job worker --task 0
-            (host1)$ train.py --job ps --task 0
+            (host1)$ CUDA_VISIBLE_DEVICES= train.py --job ps --task 0
             (host2)$ train.py --job worker --task 1
-            (host2)$ train.py --job ps --task 1
+            (host2)$ CUDA_VISIBLE_DEVICES= train.py --job ps --task 1
 
     """
 
     def __init__(self, towers, server):
@@ -152,8 +176,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
         Args:
             towers (list[int]): list of GPU ids.
             server (tf.train.Server): the server with ps and workers.
-                The job_name must be 'worker' because 'ps' job doesn't need to
-                build any graph.
+                job_name must be 'worker'.
         """
         DataParallelBuilder.__init__(self, towers)
         DistributedBuilderBase.__init__(self, server)
...
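As a side note to the launch commands shown above, a hedged sketch of how a ``train.py`` could map the ``--job``/``--task`` flags onto the ``tf.train.Server`` expected by ``DistributedReplicatedBuilder``. The flag names and addresses are illustrative only, not taken from this repository:

.. code-block:: python

    import argparse
    import tensorflow as tf

    parser = argparse.ArgumentParser()
    parser.add_argument('--job', choices=['ps', 'worker'])
    parser.add_argument('--task', type=int, default=0)
    args = parser.parse_args()

    cluster = tf.train.ClusterSpec({
        'ps': ['host1:2222', 'host2:2222'],
        'worker': ['host1:2223', 'host2:2223'],
    })
    server = tf.train.Server(cluster, job_name=args.job, task_index=args.task)

    if args.job == 'ps':
        # The ps job only serves variables and builds no graph,
        # which is why it can be started with CUDA_VISIBLE_DEVICES=.
        server.join()
    else:
        builder = DistributedReplicatedBuilder(towers=[0, 1], server=server)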
@@ -160,7 +160,6 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
 class DistributedTrainerBase(SingleCostTrainer):
 
     devices = None
-    # TODO use full device name instead of id
 
     def __init__(self, gpus, server):
         super(DistributedTrainerBase, self).__init__()
@@ -195,6 +194,8 @@ class DistributedTrainerParameterServer(DistributedTrainerBase):
         """
         Args:
             gpus ([int]): list of GPU ids.
+            server (tf.train.Server): the server with ps and workers.
+            caching_device (str): either 'cpu' or 'gpu'. The device on which to cache variables copied from PS.
         """
         super(DistributedTrainerParameterServer, self).__init__(gpus, server)
         assert self.job_name in ['ps', 'worker'], self.job_name
...
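For completeness, a sketch of how the newly documented arguments could be passed to the trainer. Only ``DistributedTrainerParameterServer``, defined in this file, and its argument names come from the diff; the server construction and addresses are assumed:

.. code-block:: python

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps': ['host1:2222'],
        'worker': ['host1:2223', 'host2:2223'],
    })
    server = tf.train.Server(cluster, job_name='worker', task_index=0)

    # gpus, server and caching_device match the Args documented above.
    trainer = DistributedTrainerParameterServer(
        gpus=[0, 1], server=server, caching_device='cpu')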