Commit 1d99dc4e authored by Yuxin Wu

organize name scopes in trainers

parent 5f750f13
@@ -4,7 +4,7 @@
import tensorflow as tf
import re
from six.moves import zip, range
from six.moves import range
from ..utils.argtools import memoized
from ..tfutils.common import get_op_tensor_name, get_global_step_var
@@ -194,32 +194,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
# Device for queues for managing synchronization between servers
self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
@staticmethod
def _average_grads(tower_grads, devices):
"""
Average grads from towers.
The device where the average happens is chosen with round-robin.
Args:
tower_grads: Ngpu x Nvar x 2
Returns:
Nvar x 2
"""
nr_device = len(devices)
if nr_device == 1:
return tower_grads[0]
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for i, grad_and_vars in enumerate(zip(*tower_grads)):
v = grad_and_vars[0][1]  # grad_and_vars is Ngpu x 2; every tower holds the same variable
all_grads = [g for (g, _) in grad_and_vars]
with tf.device(devices[i % nr_device]):
grad = tf.multiply(
tf.add_n(all_grads), 1.0 / nr_device)
new_tower_grads.append((grad, v))
return new_tower_grads
@staticmethod
def _apply_shadow_vars(avg_grads):
"""
@@ -298,7 +272,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
use_vs=[True] * len(self.towers)) # open vs at each tower
DataParallelBuilder._check_grad_list(grad_list)
avg_grads = DistributedReplicatedBuilder._average_grads(grad_list, self.raw_devices)
avg_grads = average_grads(grad_list, devices=self.raw_devices)
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(
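The averaged gradients are then handed to shadow variables created under `tf.device(self.param_server_device)`. As background, a minimal sketch of that device-placement idea: variables pinned to a parameter-server device, compute ops pinned to a worker GPU. The device strings below are assumptions for illustration (the builder derives its own from the job configuration, as in the `'/job:ps/task:%s/cpu:0'` list earlier in this file).

import tensorflow as tf

# Variables live on the PS device; the op that consumes them runs on a worker GPU.
with tf.device('/job:ps/task:0/cpu:0'):
    w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())
with tf.device('/job:worker/task:0/gpu:0'):
    y = tf.reduce_sum(w * 2.0)   # built on the worker; reads w from the PS at run time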
@@ -312,8 +286,10 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
'post_copy_barrier', [main_fetch])
# initial local_vars syncing
with tf.name_scope('initial_sync_variables'):
initial_sync_op = self._get_initial_sync_op()
if len(self._shadow_model_vars) and self.is_chief:
with tf.name_scope('sync_model_variables'):
model_sync_op = self._get_sync_model_vars_op()
else:
model_sync_op = None
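The only change in this hunk is wrapping the existing sync ops in name scopes, which matches the commit message. A tiny sketch of what that buys (the inner op name is illustrative): every op created inside the scope gets the prefix, so the sync machinery collapses into a single node in the TensorBoard graph.

import tensorflow as tf

with tf.name_scope('initial_sync_variables'):
    sync_op = tf.no_op(name='sync_all')
print(sync_op.name)   # -> 'initial_sync_variables/sync_all'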
@@ -332,6 +308,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
list of copy ops
"""
# TODO do this for variables together?
with tf.name_scope('apply_gradients'):
var_update_ops = []
for vid, (g, v) in enumerate(ps_var_grads):
# TODO do we put momentum variables into local or global?
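Background for the TODO above: `apply_gradients` itself creates the optimizer's slot variables (momentum accumulators and the like), which is why the builder has to decide whether those end up as local or global variables. A toy example, with a hypothetical optimizer choice:

import tensorflow as tf

# apply_gradients creates the momentum slot variable as a side effect.
opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
update = opt.apply_gradients([(tf.ones_like(w), w)])
print(opt.get_slot(w, 'momentum'))   # the slot variable created by apply_gradients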
@@ -218,6 +218,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
train_ops = []
opt = get_opt_fn()
with tf.name_scope('apply_gradients'):
for idx, grad_and_vars in enumerate(grads):
with tf.device(raw_devices[idx]):
# apply_gradients may create variables. Make them LOCAL_VARIABLES
@@ -226,6 +227,7 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
grad_and_vars, name='apply_grad_{}'.format(idx)))
train_op = tf.group(*train_ops, name='train_op')
with tf.name_scope('sync_variables'):
post_init_op = SyncMultiGPUReplicatedBuilder.get_post_init_ops()
return train_op, post_init_op
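The new `sync_variables` scope wraps `get_post_init_ops()`, whose body is not shown in this diff. As a rough, assumption-laden sketch of what such a post-init sync typically does in replicated training: copy the first tower's values into every other tower's copy of each variable. The `towerN/` naming scheme below is a guess for illustration, not taken from this file.

import tensorflow as tf

def make_post_init_sync_op():
    # Hypothetical helper: assign every non-chief tower's variable from the
    # matching variable under 'tower0/'. Name matching is an assumption.
    all_vars = tf.global_variables() + tf.local_variables()
    by_name = {v.name: v for v in all_vars}
    ops = []
    with tf.name_scope('sync_variables'):
        for v in all_vars:
            if v.name.startswith('tower') and '/' in v.name and not v.name.startswith('tower0/'):
                # e.g. 'tower1/conv0/W:0' is assigned from 'tower0/conv0/W:0'
                src = by_name.get('tower0/' + v.name.split('/', 1)[1])
                if src is not None:
                    ops.append(v.assign(src))
        return tf.group(*ops, name='post_init_sync')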