Commit b0677681 authored by Yuxin Wu

split variable strategies into methods

parent 9fd5cb9f
@@ -20,6 +20,7 @@ from ..callbacks.monitor import Monitors
__all__ = ['DistributedReplicatedTrainer']
# Note that only trainable vars are shadowed
PS_SHADOW_VAR_PREFIX = 'ps_var'
@@ -83,8 +84,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
    def _average_grads(tower_grads, devices):
        """
        Average grads with round-robin device selection.

        Args:
            tower_grads: Ngpu x Nvar x 2

        Returns:
            Nvar x 2
        """
        nr_device = len(devices)
        if nr_device == 1:
@@ -104,6 +109,46 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
            new_tower_grads.append((grad, v))
        return new_tower_grads
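For context, since the hunk above cuts the method in half: a minimal self-contained sketch of the round-robin averaging idea, assuming `tower_grads` is `Ngpu x Nvar x 2` (one `[(grad, var), ...]` list per GPU) and `devices` is a list of device strings. This is an illustration, not necessarily the exact lines hidden by the hunk boundary:

```python
import tensorflow as tf

def average_grads_round_robin(tower_grads, devices):
    nr_device = len(devices)
    if nr_device == 1:
        return tower_grads[0]
    new_tower_grads = []
    # zip(*tower_grads) groups the per-GPU (grad, var) pairs by variable.
    for i, grad_and_vars in enumerate(zip(*tower_grads)):
        v = grad_and_vars[0][1]
        # Round-robin: place each variable's reduction on a different device,
        # spreading the averaging work instead of funneling it to one device.
        with tf.device(devices[i % nr_device]):
            grad = tf.multiply(
                tf.add_n([x[0] for x in grad_and_vars]), 1.0 / nr_device)
        new_tower_grads.append((grad, v))
    return new_tower_grads
```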
    @staticmethod
    def _apply_shadow_vars(avg_grads):
        """
        Replace variables in avg_grads by shadow variables.
        """
        ps_var_grads = []
        for grad, var in avg_grads:
            my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
            my_name = get_op_tensor_name(my_name)[0]
            new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
                                    initializer=var.initial_value,
                                    trainable=True)
            # (g, v) to be applied, where v is global (ps vars)
            ps_var_grads.append((grad, new_v))
        return ps_var_grads
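The name juggling here exists because `var.name` is a tensor name ending in `:0`, which is not a legal variable name; `get_op_tensor_name` strips that suffix. A small illustration (the variable name `conv0/W:0` is made up, and the import path assumes tensorpack's layout at the time):

```python
from tensorpack.tfutils.common import get_op_tensor_name

name = PS_SHADOW_VAR_PREFIX + '/' + 'conv0/W:0'   # 'ps_var/conv0/W:0'
print(get_op_tensor_name(name)[0])                # 'ps_var/conv0/W'
```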
    def _apply_gradients_and_copy(self, raw_grad_list, ps_var_grads):
        """
        Args:
            raw_grad_list: Ngpu x Nvar x 2 gradient list from all towers
            ps_var_grads: Nvar x 2 (grad, ps_var)

        Returns:
            list of copy ops
        """
        # TODO do this for each variable separately?
        opt = self.model.get_optimizer()  # TODO ensure it is in the global scope, not local
        var_update_ops = []
        for vid, (g, v) in enumerate(ps_var_grads):
            apply_gradient_op = opt.apply_gradients([(g, v)])
            barrier = self.add_sync_queues_and_barrier(
                'param_update_barrier_{}'.format(vid), [apply_gradient_op])
            with tf.control_dependencies([barrier]), \
                    tf.device(self.cpu_device):
                updated_value = v.read_value()
                for towerid in range(self.nr_gpu):
                    var_update_ops.append(
                        raw_grad_list[towerid][vid][1].assign(updated_value))
        return var_update_ops
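A note on the indexing, since it is easy to misread: `raw_grad_list` is the original per-tower list, so `raw_grad_list[towerid][vid][1]` is tower `towerid`'s local copy of variable `vid`, and the `assign` broadcasts the freshly applied PS value back to it. The per-variable barrier guarantees every worker has applied its gradient for that variable before anyone reads `v.read_value()`. Schematically (illustrative only):

```python
# raw_grad_list: Ngpu x Nvar x 2
grad, local_var = raw_grad_list[towerid][vid]   # one tower's (grad, var) pair
copy_op = local_var.assign(updated_value)       # overwrite with the PS value
```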
    def _setup(self):
        conf = get_default_sess_config()
        self.server = tf.train.Server(
@@ -128,35 +173,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                var_strategy='replicated')
        avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
        # Nvar * 2
        ps_var_grads = []
        for i, (grad, var) in enumerate(avg_grads):
            with tf.device(self.param_server_device):
                my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
                my_name = get_op_tensor_name(my_name)[0]
                new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
                                        initializer=var.initial_value,
                                        trainable=True)
                # (g, v) to be applied, where v is global (ps vars)
                ps_var_grads.append((grad, new_v))
        # apply gradients TODO do this for each variable separately?
        var_update_ops = []
        with tf.device(self.param_server_device):
            for vid, (g, v) in enumerate(ps_var_grads):
                apply_gradient_op = opt.apply_gradients([(g, v)])
                barrier = self.add_sync_queues_and_barrier(
                    'param_update_barrier_{}'.format(vid), [apply_gradient_op])
                with tf.control_dependencies([barrier]), \
                        tf.device(self.cpu_device):
                    updated_value = v.read_value()
                    for towerid in range(self.nr_gpu):
                        logger.info("Step update {} -> {}".format(v.name, grad_list[towerid][vid][1].name))
                        var_update_ops.append(
                            grad_list[towerid][vid][1].assign(updated_value))
        self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
        # self.train_op = self.main_fetch
        self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
        ps_var_grads = DistributedReplicatedTrainer._apply_shadow_vars(avg_grads)
        var_update_ops = self._apply_gradients_and_copy(grad_list, ps_var_grads)
        main_fetch = tf.group(*var_update_ops, name='main_fetches')
        self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [main_fetch])
        self.post_init_op = self.get_post_init_ops()
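Net effect of this hunk: the monolithic `_setup` body is replaced by three named steps, `_average_grads` (reduce per-tower gradients), `_apply_shadow_vars` (re-target them onto PS-side shadow variables), and `_apply_gradients_and_copy` (apply on the PS, then copy the results back to every tower), with the whole step still gated by the `sync_queues_step_end` barrier. The per-step `logger.info("Step update ...")` call from the old inline version does not reappear in `_apply_gradients_and_copy`, and `main_fetch` becomes a local rather than an attribute.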
    def setup(self):
@@ -185,10 +207,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
            summary_op=None,
            save_model_secs=0,
            summary_writer=None)
        conf = get_default_sess_config()
        sess = self.sv.prepare_or_wait_for_session(
            master=self.server.target,
            config=conf,
            start_standard_services=False)
        self.sess = sess
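For readers unfamiliar with the TF 1.x `Supervisor` flow used here: the chief worker runs variable initialization, and every other worker blocks in `prepare_or_wait_for_session` until the shared variables are ready. A minimal standalone sketch (cluster addresses and `task_index` are placeholders):

```python
import tensorflow as tf

cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222']})
task_index = 0
server = tf.train.Server(cluster, job_name='worker', task_index=task_index)

sv = tf.train.Supervisor(is_chief=(task_index == 0),
                         logdir=None, saver=None, summary_op=None,
                         save_model_secs=0, summary_writer=None)
# The chief initializes variables; non-chief workers wait here until the
# shared variables are ready, then connect to the same graph.
sess = sv.prepare_or_wait_for_session(
    master=server.target, start_standard_services=False)
```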
@@ -198,7 +218,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
        self._monitored_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=None)
        # self._monitored_sess = self.sv
        hooks = self._callbacks.get_hooks()
        self.hooked_sess = HookedSession(self.sess, hooks)
@@ -213,7 +232,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
            an op that should be used as a control dependency before starting the next step.
        """
        self.sync_queue_counter += 1
        self.num_worker = self.cluster.num_tasks('worker')
        with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
            sync_queues = [
                tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
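The hunk is truncated here, but the shape of the barrier is the standard shared-queue trick (as in TensorFlow's distributed benchmark code): every worker enqueues a token into every *other* worker's queue, then dequeues `num_worker - 1` tokens from its own, which can only complete once all peers have checked in. A hedged sketch of the full pattern (attribute names follow the surrounding code, passed in as arguments here):

```python
import tensorflow as tf

def sync_queues_and_barrier(name, dependencies, num_worker, task_index):
    # One shared queue per worker; `shared_name` must be identical across
    # workers so they all refer to the same underlying queues.
    sync_queues = [
        tf.FIFOQueue(num_worker, [tf.bool], shapes=[[]],
                     shared_name='{}-{}'.format(name, i))
        for i in range(num_worker)]
    queue_ops = []
    token = tf.constant(False)
    with tf.control_dependencies(dependencies):
        for i, q in enumerate(sync_queues):
            if i != task_index:
                # Signal every other worker that this worker is done.
                queue_ops.append(q.enqueue(token))
    # Block until every other worker has checked in.
    queue_ops.append(
        sync_queues[task_index].dequeue_many(num_worker - 1))
    return tf.group(*queue_ops)
```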
@@ -257,7 +275,6 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                name = 'tower%s/%s' % (i, prefix)
                if name in local_var_by_name:
                    copy_to = local_var_by_name[name]
                    logger.info("Post Init {} -> {}".format(v.name, copy_to.name))
                    post_init_ops.append(copy_to.assign(v.read_value()))
                else:
                    logger.warn("Global var {} doesn't match local var".format(v.name))
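The surrounding hunk implements the post-init broadcast: after initialization, each global (PS shadow) variable is copied into the matching `towerN/...` local variable, so every replica starts from identical weights. A condensed sketch of the matching logic; the helper signature and the prefix handling are illustrative, not the exact code outside the hunk:

```python
import tensorflow as tf

def make_post_init_ops(global_vars, local_var_by_name, nr_gpu):
    post_init_ops = []
    for v in global_vars:
        # In the real code the PS_SHADOW_VAR_PREFIX is stripped from v.name
        # first; `prefix` stands in for that stripped name here.
        prefix = v.name
        for i in range(nr_gpu):
            name = 'tower%s/%s' % (i, prefix)
            if name in local_var_by_name:
                # Overwrite the tower-local replica with the global value.
                post_init_ops.append(
                    local_var_by_name[name].assign(v.read_value()))
    return tf.group(*post_init_ops, name='post_init_ops')
```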