Commit 6efe0deb authored by Yuxin Wu

update docs

parent c2cec01e
......@@ -23,7 +23,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
and get synchronously applied to the global copy of variables located on PS.
Then each worker copies the latest variables from PS back to local.
It is an equivalent of `--variable_update=distributed_replicated` in
It is an equivalent of ``--variable_update=distributed_replicated`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
Note:
......
......@@ -106,7 +106,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
shared variable scope. It synchronizes the gradients computed
from each tower, averages them and applies to the shared variables.
It is an equivalent of `--variable_update=parameter_server` in
It is an equivalent of ``--variable_update=parameter_server`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
"""
def __init__(self, towers, ps_device=None):
......@@ -165,7 +165,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
It will build one tower on each GPU under its own variable scope.
Each gradient update is averaged across all GPUs through NCCL.
It is an equivalent of `--variable_update=replicated` in
It is an equivalent of ``--variable_update=replicated`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
"""
......
......@@ -71,13 +71,16 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
"""
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, ps_device='gpu'):
def __init__(self, gpus, ps_device=None):
"""
Args:
gpus ([int]): list of GPU ids.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
ps_device: either 'gpu' or 'cpu', where variables are stored.
The default value is subject to change.
"""
self.devices = gpus
if ps_device is None:
ps_device = 'gpu' if len(gpus) <= 2 else 'cpu'
self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
super(SyncMultiGPUTrainerParameterServer, self).__init__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment