Commit 6efe0deb authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent c2cec01e
...@@ -23,7 +23,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder): ...@@ -23,7 +23,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
and get synchronously applied to the global copy of variables located on PS. and get synchronously applied to the global copy of variables located on PS.
Then each worker copies the latest variables from PS back to local. Then each worker copies the latest variables from PS back to local.
It is an equivalent of `--variable_update=distributed_replicated` in It is an equivalent of ``--variable_update=distributed_replicated`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_. `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
Note: Note:
......
...@@ -106,7 +106,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder): ...@@ -106,7 +106,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
shared variable scope. It synchronizes the gradients computed shared variable scope. It synchronizes the gradients computed
from each tower, averages them and applies to the shared variables. from each tower, averages them and applies to the shared variables.
It is an equivalent of `--variable_update=parameter_server` in It is an equivalent of ``--variable_update=parameter_server`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_. `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
""" """
def __init__(self, towers, ps_device=None): def __init__(self, towers, ps_device=None):
...@@ -165,7 +165,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder): ...@@ -165,7 +165,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
It will build one tower on each GPU under its own variable scope. It will build one tower on each GPU under its own variable scope.
Each gradient update is averaged across all GPUs through NCCL. Each gradient update is averaged across all GPUs through NCCL.
It is an equivalent of `--variable_update=replicated` in It is an equivalent of ``--variable_update=replicated`` in
`tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_. `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
""" """
......
...@@ -71,13 +71,16 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer): ...@@ -71,13 +71,16 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
""" """
@map_arg(gpus=_int_to_range) @map_arg(gpus=_int_to_range)
def __init__(self, gpus, ps_device='gpu'): def __init__(self, gpus, ps_device=None):
""" """
Args: Args:
gpus ([int]): list of GPU ids. gpus ([int]): list of GPU ids.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4 ps_device: either 'gpu' or 'cpu', where variables are stored.
The default value is subject to change.
""" """
self.devices = gpus self.devices = gpus
if ps_device is None:
ps_device = 'gpu' if len(gpus) <= 2 else 'cpu'
self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device) self._builder = SyncMultiGPUParameterServerBuilder(gpus, ps_device)
super(SyncMultiGPUTrainerParameterServer, self).__init__() super(SyncMultiGPUTrainerParameterServer, self).__init__()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment