Commit 82187086 authored by Yuxin Wu

Share documents between builder & trainer

parent 5b8ed8be
@@ -17,7 +17,7 @@ __all__ = ['DistributedReplicatedBuilder']
 class DistributedReplicatedBuilder(DataParallelBuilder):
     """
-    Graph builder for distributed replicated training.
+    Distributed replicated training.
     Each worker process builds the same model on one or more GPUs.
     Gradients across GPUs are averaged within the worker,
     and get synchronously applied to the global copy of variables located on PS.
@@ -28,6 +28,28 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
     Note:
         Gradients are not averaged across workers, but applied to PS variables
         directly (either with or without locking depending on the optimizer).
+
+    Example:
+
+        .. code-block:: python
+
+            # Create the server object like this:
+            hosts = ['host1.com', 'host2.com']
+            cluster_spec = tf.train.ClusterSpec({
+                'ps': [h + ':2222' for h in hosts],
+                'worker': [h + ':2223' for h in hosts]
+            })
+            server = tf.train.Server(
+                cluster_spec, job_name=args.job, task_index=args.task,
+                config=get_default_sess_config())
+
+        .. code-block:: none
+
+            # Start training like this:
+            (host1)$ train.py --job worker --task 0
+            (host1)$ train.py --job ps --task 0
+            (host2)$ train.py --job worker --task 1
+            (host2)$ train.py --job ps --task 1
     """
     def __init__(self, towers, server):
...
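The docstring example above refers to `args.job`, `args.task` and `get_default_sess_config()` without showing where they come from. A minimal sketch of the argument parsing it seems to assume, based only on the `--job`/`--task` flags in the launch commands (the defaults and help strings here are illustrative, not from the source):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--job', choices=['ps', 'worker'], required=True,
                    help="role of this process in the cluster")
parser.add_argument('--task', type=int, default=0,
                    help="index of this process within its job")
args = parser.parse_args()

# A 'ps' process would typically just join the server and wait for workers,
# e.g. `if args.job == 'ps': server.join()`.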
@@ -15,6 +15,8 @@ from ..tfutils.common import get_tf_version_number
 from ..tfutils.collection import backup_collection, restore_collection
 from ..tfutils.gradproc import ScaleGradient
 from ..utils.naming import TOWER_FREEZE_KEYS
+from ..input_source import FeedfreeInput
 from .utils import LeastLoadedDeviceSetter, override_to_local_variable
@@ -32,7 +34,7 @@ class GraphBuilder(object):
 class SimpleBuilder(GraphBuilder):
     """
-    Build the graph for single-cost single-optimizer single-tower training.
+    Single-cost single-optimizer single-tower training.
     """
     def build(self, input, get_cost_fn, get_opt_fn):
         """
@@ -133,7 +135,8 @@ class DataParallelBuilder(GraphBuilder):
     @staticmethod
     def _make_fn(input, get_cost_fn, get_opt_fn):
         # internal use only
-        assert input.setup_done()
+        assert input.setup_done(), "InputSource must have been setup before calling GraphBuilder!"
+        assert isinstance(input, FeedfreeInput), input
         get_opt_fn = memoized(get_opt_fn)

         def get_grad_fn():
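The `memoized(get_opt_fn)` wrapper guarantees the optimizer is constructed only once, even though every tower's gradient function may ask for it. A minimal stand-in for such a memoizer, shown purely to illustrate the idea (this is not tensorpack's actual implementation):

import functools

def memoized(func):
    # Cache the result of a zero-argument callable so every caller
    # gets the same object back (e.g. the same optimizer instance).
    @functools.wraps(func)
    def wrapper():
        if not hasattr(wrapper, '_result'):
            wrapper._result = func()
        return wrapper._result
    return wrapper

calls = []
get_opt_fn = memoized(lambda: calls.append(1) or object())
opt_a, opt_b = get_opt_fn(), get_opt_fn()
assert opt_a is opt_b and len(calls) == 1   # constructed exactly once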
@@ -153,7 +156,7 @@ class DataParallelBuilder(GraphBuilder):
 class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     """
-    Graph builder for data-parallel training in 'ParameterServer' mode.
+    Data-parallel training in 'ParameterServer' mode.
     It builds one tower on each GPU with
     shared variable scope. It synchronizes the gradients computed
     from each tower, averages them and applies to the shared variables.
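The "synchronizes ... averages them and applies to the shared variables" step boils down to a per-variable average of the tower gradients. A simplified sketch of that reduction (not the builder's exact code; the list-of-lists layout of `all_grads` is an assumption):

import tensorflow as tf

def average_grads(all_grads):
    # all_grads: one list of (grad, var) pairs per tower, where the k-th
    # pair in every tower refers to the same variable.
    averaged = []
    for grads_and_var in zip(*all_grads):
        grads = [g for g, _ in grads_and_var]
        var = grads_and_var[0][1]
        averaged.append((tf.multiply(tf.add_n(grads), 1.0 / len(grads)), var))
    return averaged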
@@ -234,7 +237,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
 class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
     """
-    Graph builder for data-parallel training in "replicated" mode,
+    Data-parallel training in "replicated" mode,
     where each GPU contains a replicate of the whole model.
     It will build one tower on each GPU under its own variable scope.
     Each gradient update is averaged across all GPUs through NCCL.
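In TF 1.x the NCCL-based averaging can be expressed with `tf.contrib.nccl.all_sum`, which returns one summed copy of each gradient per GPU, leaving only a division by the number of towers. A rough sketch under that assumption; the real builder code is more involved:

import tensorflow as tf
from tensorflow.contrib import nccl  # TF 1.x only

def allreduce_grads(tower_grads):
    # tower_grads: one list of (grad, var) pairs per GPU.
    num_towers = len(tower_grads)
    result = [[] for _ in range(num_towers)]
    for grads_and_var in zip(*tower_grads):
        grads = [g for g, _ in grads_and_var]
        summed = nccl.all_sum(grads)          # one summed copy per GPU
        for i, (s, (_, var)) in enumerate(zip(summed, grads_and_var)):
            result[i].append((s * (1.0 / num_towers), var))
    return result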
@@ -338,7 +341,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
 class AsyncMultiGPUBuilder(DataParallelBuilder):
     """
-    Graph builder for data-parallel training with async update.
+    Data-parallel training with async update.
     It builds one tower on each GPU with shared variable scope.
     Every tower computes the gradients and independently applies them to the
     variables, without synchronizing and averaging across towers.
...
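"Independently applies them" means each tower gets its own `apply_gradients` op and no cross-tower reduction happens at all. Schematically (this is a sketch, not the actual builder code; the function name and op names are made up):

import tensorflow as tf

def async_train_ops(tower_grads, get_opt_fn):
    # tower_grads: one list of (grad, var) pairs per GPU.
    # Each tower applies its own gradients to the shared variables;
    # there is no averaging or synchronization between towers.
    opt = get_opt_fn()
    ops = [opt.apply_gradients(grads, name='apply_grad_tower{}'.format(i))
           for i, grads in enumerate(tower_grads)]
    return tf.group(*ops, name='async_train_op')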
@@ -19,35 +19,9 @@ __all__ = ['DistributedTrainerReplicated']
 class DistributedTrainerReplicated(Trainer):
-    """
-    Build the graph with :class:`DistributedReplicatedBuilder` and train it.
-
-    Note:
-        Gradients are not averaged across workers, but applied to PS variables
-        directly (either with or without locking depending on the optimizer).
-
-    Example:
-
-        .. code-block:: python
-
-            hosts = ['host1.com', 'host2.com']
-            cluster_spec = tf.train.ClusterSpec({
-                'ps': [h + ':2222' for h in hosts],
-                'worker': [h + ':2223' for h in hosts]
-            })
-            server = tf.train.Server(
-                cluster_spec, job_name=args.job, task_index=args.task,
-                config=get_default_sess_config())
-            DistributedTrainerReplicated(config, server).train()
-
-        .. code-block:: none
-
-            # start your jobs:
-            (host1)$ train.py --job worker --task 0
-            (host1)$ train.py --job ps --task 0
-            (host2)$ train.py --job worker --task 1
-            (host2)$ train.py --job ps --task 1
-    """
+
+    __doc__ = DistributedReplicatedBuilder.__doc__

     def __init__(self, config, server):
         """
         Args:
@@ -114,7 +88,7 @@ class DistributedTrainerReplicated(Trainer):
                 or self._config.session_config is not None:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
-                "To use a custom session config, pass it with tf.train.Server.")
+                "To use a custom session config, pass it to tf.train.Server.")
         self._config.session_creator = get_distributed_session_creator(self.server)
...
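The corrected error message ("pass it to tf.train.Server") works because `tf.train.Server` accepts a `ConfigProto` directly, so per-session options belong there rather than on the trainer's config. A minimal sketch, with illustrative hostnames and options:

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({
    'ps': ['host1.com:2222'],
    'worker': ['host1.com:2223', 'host2.com:2223']})

# Custom session options go into the server, not the TrainConfig.
session_config = tf.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.allow_growth = True

server = tf.train.Server(
    cluster_spec, job_name='worker', task_index=0,
    config=session_config)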
@@ -48,9 +48,8 @@ def apply_prefetch_policy(config, gpu_prefetch=True):
 class SyncMultiGPUTrainerParameterServer(Trainer):
-    """
-    Build graph with :class:`SyncMultiGPUParameterServerBuilder` and train it.
-    """
+
+    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

     def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
         """
@@ -86,9 +85,9 @@ def SyncMultiGPUTrainer(config):
 class SyncMultiGPUTrainerReplicated(Trainer):
-    """
-    Build graph with :class:`SyncMultiGPUReplicatedBuilder` and train it.
-    """
+
+    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

     def __init__(self, config, gpu_prefetch=True):
         """
         Args:
@@ -111,9 +110,9 @@ class SyncMultiGPUTrainerReplicated(Trainer):
 class AsyncMultiGPUTrainer(Trainer):
-    """
-    Build graph with :class:`AsyncMultiGPUBuilder` and train it.
-    """
+
+    __doc__ = AsyncMultiGPUBuilder.__doc__

     def __init__(self, config, scale_gradient=True):
         """
         Args:
...
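The pattern applied throughout this commit, assigning `__doc__ = SomeBuilder.__doc__` inside the trainer's class body, keeps each builder and its trainer documented from a single source. Stripped of the tensorpack specifics, the mechanism looks like this (toy classes, not the real ones):

class EngineBuilder(object):
    """Builds the engine.

    This docstring is written once, on the builder.
    """

class EngineTrainer(object):
    # Reuse the builder's documentation instead of duplicating it,
    # exactly as the trainers above reuse their builders' docstrings.
    __doc__ = EngineBuilder.__doc__

assert EngineTrainer.__doc__ == EngineBuilder.__doc__
print(EngineTrainer.__doc__)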