Commit 6d954998 authored by Yuxin Wu's avatar Yuxin Wu

update docs for distributed trainer (#375)

parent 864a35f6
@@ -45,7 +45,30 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
See https://www.tensorflow.org/performance/benchmarks for details.
Note:
Gradients are not averaged across workers.
Gradients are not averaged across workers, but applied to PS variables
directly (either with or without locking depending on the optimizer).
Example:
.. code-block:: python
hosts = ['host1.com', 'host2.com']
cluster_spec = tf.train.ClusterSpec({
'ps': [h + ':2222' for h in hosts],
'worker': [h + ':2223' for h in hosts]
})
server = tf.train.Server(
cluster_spec, job_name=args.job, task_index=args.task,
config=get_default_sess_config())
DistributedTrainerReplicated(config, server).train()
.. code-block::
# start your jobs:
(host1)$ train.py --job worker --task 0
(host1)$ train.py --job ps --task 0
(host2)$ train.py --job worker --task 1
(host2)$ train.py --job ps --task 1
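(Whether the PS applies each worker's gradients atomically depends on the
optimizer, as noted above. A minimal, hypothetical illustration, not part of
this trainer; ``tf.train.MomentumOptimizer`` is just one optimizer that
exposes ``use_locking``.)
.. code-block:: python
    import tensorflow as tf

    # use_locking=True requests locked (atomic) updates on the PS variables;
    # with the default use_locking=False, concurrent updates may interleave
    opt = tf.train.MomentumOptimizer(
        learning_rate=0.1, momentum=0.9, use_locking=True)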
"""
def __init__(self, config, server):
"""
@@ -61,6 +84,8 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
self.task_index = server_def.task_index
assert self.job_name in ['ps', 'worker'], self.job_name
assert tf.test.is_gpu_available()
logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
self._input_source = config.data
self.is_chief = (self.task_index == 0 and self.job_name == 'worker')
@@ -112,7 +137,8 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
@staticmethod
def _apply_shadow_vars(avg_grads):
"""
Replace variables in avg_grads by shadow variables.
Create shadow variables on PS, and replace variables in avg_grads
by these shadow variables.
Args:
avg_grads: list of (grad, var) tuples
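A rough sketch of the idea behind this step (not the actual implementation;
the device placement, ``/shadow`` naming and zero initialization below are
assumptions made for illustration):
.. code-block:: python
    import tensorflow as tf

    def make_shadow_vars(avg_grads):
        # pair each gradient with a same-shaped variable created on the PS
        ps_var_grads = []
        with tf.device('/job:ps/task:0/cpu:0'):   # placement chosen for illustration
            for grad, var in avg_grads:
                shadow = tf.get_variable(
                    var.op.name + '/shadow',       # hypothetical naming scheme
                    shape=var.shape, dtype=var.dtype.base_dtype,
                    initializer=tf.zeros_initializer(),  # real code would reuse var's initial value
                    trainable=True)
                ps_var_grads.append((grad, shadow))
        return ps_var_grads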
@@ -156,6 +182,9 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
"""
Apply averaged gradients to ps vars, and then copy the updated
variables back to each tower.
Args:
raw_grad_list: Ngpu x Nvar x 2 gradient list from all towers
ps_var_grads: Nvar x 2 (grad, ps_var)
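A simplified sketch of this step, assuming ``opt`` is the optimizer and
``tower_vars`` (not an argument of this method, introduced only for the
sketch) holds the Ngpu x Nvar per-tower variables:
.. code-block:: python
    import tensorflow as tf

    def apply_gradients_and_copy(opt, raw_grad_list, ps_var_grads, tower_vars):
        update_ops = []
        for vid, (_, ps_var) in enumerate(ps_var_grads):
            # average this variable's gradient over all towers
            grads = [tower_grads[vid][0] for tower_grads in raw_grad_list]
            avg_grad = tf.add_n(grads) / float(len(grads))
            # apply the averaged gradient to the PS variable ...
            apply_op = opt.apply_gradients([(avg_grad, ps_var)])
            with tf.control_dependencies([apply_op]):
                # ... then copy the updated value back into every tower's copy
                for tvars in tower_vars:
                    update_ops.append(tvars[vid].assign(ps_var.read_value()))
        return update_ops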
@@ -226,7 +255,7 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
cb = RunOp(self._get_sync_model_vars_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequenctly.")
"every epoch. Be careful if you save the model more frequently than this.")
self.register_callback(cb)
self._set_session_creator()