Commit 9fd5cb9f authored by Yuxin Wu

fix batchrenorm, simplify dist-trainer code

parent 86d1b2e5
@@ -15,7 +15,7 @@ from .multigpu import MultiGPUTrainerBase
 from ..tfutils.model_utils import describe_model
 from ..callbacks import Callbacks, ProgressBar
 from ..tfutils.sesscreate import ReuseSessionCreator
-from ..tfutils.common import get_default_sess_config, get_global_step_var
+from ..tfutils.common import get_default_sess_config, get_global_step_var, get_op_tensor_name
 from ..callbacks.monitor import Monitors

 __all__ = ['DistributedReplicatedTrainer']
@@ -50,7 +50,6 @@ class OverrideToLocalVariableIfNotPsVar(object):
 class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
     def __init__(self, config, job_name, task_index, cluster):
         assert job_name in ['ps', 'worker'], job_name
-        self.config = config
         self.job_name = job_name
         self.task_index = task_index
         self.cluster = cluster
@@ -61,14 +60,16 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
         worker_prefix = '/job:worker/task:%s' % self.task_index
         self.param_server_device = tf.train.replica_device_setter(
             worker_device=worker_prefix + '/cpu:0', cluster=self.cluster)
-        # This device on which the queues for managing synchronization between
-        # servers should be stored.
-        num_ps = self.cluster.num_tasks('ps')
-        self.cpu_device = '%s/cpu:0' % worker_prefix
+        self.num_ps = self.cluster.num_tasks('ps')
+        self.num_worker = self.cluster.num_tasks('worker')
         self.nr_gpu = config.nr_tower
+        self.cpu_device = '%s/cpu:0' % worker_prefix
         self.raw_devices = ['%s/%s:%i' % (worker_prefix, 'gpu', i) for i in range(self.nr_gpu)]
+        # The devices on which the queues for managing synchronization between
+        # servers should be stored.
+        self.sync_queue_devices = ['/job:ps/task:%s/cpu:0' % i for i in range(self.num_ps)]
         self.sync_queue_counter = 0

         if self.nr_gpu > 1:
@@ -78,6 +79,31 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             if not isinstance(self._input_source, StagingInputWrapper):
                 self._input_source = StagingInputWrapper(self._input_source, self.raw_devices)

+    @staticmethod
+    def _average_grads(tower_grads, devices):
+        """
+        Average gradients, assigning each variable to a device in round-robin order.
+
+        Args:
+            tower_grads: Ngpu x Nvar x 2 list of (grad, var) pairs.
+        """
+        nr_device = len(devices)
+        if nr_device == 1:
+            return tower_grads[0]
+        new_tower_grads = []
+        with tf.name_scope('AvgGrad'):
+            for i, grad_and_vars in enumerate(zip(*tower_grads)):
+                # grad_and_vars: Ngpu x 2, the (grad, var) pairs of one variable
+                with tf.device(devices[i % nr_device]):
+                    v = grad_and_vars[0][1]
+                    # average the gradients of this variable over all towers
+                    all_grads = [g for (g, _) in grad_and_vars]
+                    if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
+                        continue
+                    grad = tf.multiply(
+                        tf.add_n(all_grads), 1.0 / nr_device)
+                    new_tower_grads.append((grad, v))
+        return new_tower_grads
+
     def _setup(self):
         conf = get_default_sess_config()
         self.server = tf.train.Server(
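For reference, a minimal standalone sketch of what the new _average_grads helper computes on an Ngpu x Nvar x 2 structure. It uses plain Python numbers and strings instead of TF tensors and device contexts; average_grads_sketch and the device names are hypothetical, and the any(g is None ...) check only stands in for MultiGPUTrainerBase.check_none_grads:

def average_grads_sketch(tower_grads, devices):
    nr_device = len(devices)
    if nr_device == 1:
        return tower_grads[0]
    averaged = []
    # zip(*tower_grads) regroups the data per variable: each group holds the
    # Ngpu (grad, var) pairs of one variable.
    for i, grad_and_vars in enumerate(zip(*tower_grads)):
        device = devices[i % nr_device]      # round-robin device assignment
        var = grad_and_vars[0][1]            # the variable is the same across towers
        grads = [g for (g, _) in grad_and_vars]
        if any(g is None for g in grads):    # stand-in for check_none_grads
            continue
        averaged.append((sum(grads) / nr_device, var, device))
    return averaged

# Two towers (GPUs), two variables 'W' and 'b':
tower_grads = [[(1.0, 'W'), (3.0, 'b')],    # gradients computed on tower 0
               [(2.0, 'W'), (5.0, 'b')]]    # gradients computed on tower 1
print(average_grads_sketch(tower_grads, ['/gpu:0', '/gpu:1']))
# [(1.5, 'W', '/gpu:0'), (4.0, 'b', '/gpu:1')]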
@@ -101,35 +127,23 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             devices=self.raw_devices,
             var_strategy='replicated')

-        # (g, v) to be applied, where v is global (ps vars)
-        new_tower_grads = []
-        for i, grad_and_vars in enumerate(zip(*grad_list)):
-            # Ngpu * 2
-            with tf.device(self.raw_devices[i % self.nr_gpu]):
-                v = grad_and_vars[0][1]
-                if self.nr_gpu > 1:
-                    # average gradient
-                    all_grads = [g for (g, _) in grad_and_vars]
-                    if not MultiGPUTrainerBase.check_none_grads(v.op.name, all_grads):
-                        continue
-                    grad = tf.multiply(
-                        tf.add_n(all_grads), 1.0 / self.nr_gpu)
-                else:
-                    grad = grad_and_vars[0][0]
+        avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
+        # Nvar * 2
+        ps_var_grads = []
+        for i, (grad, var) in enumerate(avg_grads):
             with tf.device(self.param_server_device):
-                my_name = PS_SHADOW_VAR_PREFIX + '/' + v.name
-                if my_name.endswith(':0'):
-                    my_name = my_name[:-2]
-                new_v = tf.get_variable(my_name, dtype=v.dtype.base_dtype,
-                                        initializer=v.initial_value,
+                my_name = PS_SHADOW_VAR_PREFIX + '/' + var.name
+                my_name = get_op_tensor_name(my_name)[0]
+                new_v = tf.get_variable(my_name, dtype=var.dtype.base_dtype,
+                                        initializer=var.initial_value,
                                         trainable=True)
-                new_tower_grads.append((grad, new_v))
+                # (g, v) to be applied, where v is global (ps vars)
+                ps_var_grads.append((grad, new_v))

         # apply gradients TODO do this for each variable separately?
         var_update_ops = []
         with tf.device(self.param_server_device):
-            for vid, (g, v) in enumerate(new_tower_grads):
+            for vid, (g, v) in enumerate(ps_var_grads):
                 apply_gradient_op = opt.apply_gradients([(g, v)])
                 barrier = self.add_sync_queues_and_barrier(
                     'param_update_barrier_{}'.format(vid), [apply_gradient_op])
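A side note on the name handling above: get_op_tensor_name(my_name)[0] yields the op name, i.e. the tensor name without a trailing ':0', replacing the manual endswith(':0') stripping of the old code. A rough, hypothetical re-implementation of just that piece; the PS_SHADOW_VAR_PREFIX value shown is only a placeholder:

PS_SHADOW_VAR_PREFIX = 'ps_var'     # placeholder value for illustration

def op_name(tensor_name):
    # 'tower0/conv0/W:0' -> 'tower0/conv0/W', i.e. what get_op_tensor_name(...)[0] returns
    return tensor_name[:-2] if tensor_name.endswith(':0') else tensor_name

print(op_name(PS_SHADOW_VAR_PREFIX + '/' + 'tower0/conv0/W:0'))
# ps_var/tower0/conv0/W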
@@ -141,8 +155,8 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
                     var_update_ops.append(
                         grad_list[towerid][vid][1].assign(updated_value))
         self.main_fetch = tf.group(*var_update_ops, name='main_fetches')
-        self.train_op = self.main_fetch
-        #self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
+        #self.train_op = self.main_fetch
+        self.train_op = self.add_sync_queues_and_barrier('sync_queues_step_end', [self.main_fetch])
         self.post_init_op = self.get_post_init_ops()

     def setup(self):
@@ -199,12 +213,12 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
             an op that should be used as a control dependency before starting the next step.
         """
         self.sync_queue_counter += 1
-        num_workers = self.cluster.num_tasks('worker')
+        self.num_worker = self.cluster.num_tasks('worker')
         with tf.device(self.sync_queue_devices[self.sync_queue_counter % len(self.sync_queue_devices)]):
             sync_queues = [
-                tf.FIFOQueue(num_workers, [tf.bool], shapes=[[]],
+                tf.FIFOQueue(self.num_worker, [tf.bool], shapes=[[]],
                              shared_name='%s%s' % (name_prefix, i))
-                for i in range(num_workers)]
+                for i in range(self.num_worker)]
             queue_ops = []
             # For each other worker, add an entry in a queue, signaling that it can
             # finish this step.
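The two comment lines above describe the barrier protocol: every worker pushes a token into each other worker's queue and then (in the part of the method truncated here) waits until it has collected one token per other worker from its own queue. A minimal single-process sketch of that hand-shake, using threads and queue.Queue in place of the shared tf.FIFOQueues; worker_step and NUM_WORKER are illustrative names only, and the dequeue side is an assumption about the truncated code:

import queue
import threading

NUM_WORKER = 3
queues = [queue.Queue() for _ in range(NUM_WORKER)]

def worker_step(task_index, results):
    # ... this worker's training step would run here ...
    for i, q in enumerate(queues):
        if i != task_index:
            q.put(True)                 # signal every other worker
    for _ in range(NUM_WORKER - 1):
        queues[task_index].get()        # wait for every other worker's signal
    results[task_index] = 'passed barrier'

results = {}
threads = [threading.Thread(target=worker_step, args=(i, results)) for i in range(NUM_WORKER)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)    # all workers reach this point together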