Commit 93819550 authored by Yuxin Wu

Use `call_for_each_tower` just like DistStrat

parent eb88cda0
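The change below splits each multi-GPU builder's single `build(get_grad_fn, get_opt_fn)` entry point into two explicit phases, `call_for_each_tower(tower_fn)` followed by `build(grad_list, get_opt_fn)`, mirroring the shape of `DistributionStrategy.call_for_each_tower`. The toy model below is only a sketch of that control flow (plain Python, no TensorFlow; `ToyBuilder` and its return values are made up for illustration):

```python
def tower_fn():
    # Stand-in for what _make_get_grad_fn() produces: in the real code this is
    # called once per tower under a TowerContext and returns [(grad, var), ...].
    return [("g_w", "w"), ("g_b", "b")]

class ToyBuilder:
    """Hypothetical stand-in for the builders changed in this commit."""
    towers = [0, 1]

    def call_for_each_tower(self, tower_fn):
        # Phase 1: run tower_fn on every tower and collect its return values.
        return [tower_fn() for _ in self.towers]

    def build(self, grad_list, get_opt_fn):
        # Phase 2: reduce grad_list and apply the optimizer, yielding the train op.
        assert len(grad_list) == len(self.towers)
        return "train_op"

builder = ToyBuilder()
# Before this commit: train_op = builder.build(get_grad_fn, get_opt_fn)
grad_list = builder.call_for_each_tower(tower_fn)
train_op = builder.build(grad_list, get_opt_fn=lambda: None)
```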
......@@ -140,16 +140,12 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
assert ps_device in ['cpu', 'gpu']
self.ps_device = ps_device
def build(self, get_grad_fn, get_opt_fn):
def call_for_each_tower(self, tower_fn):
"""
Build the graph, and set self.grads to a list of (g, v), containing the averaged gradients.
Args:
get_grad_fn (-> [(grad, var)]):
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Call the function `tower_fn` under :class:`TowerContext` for each tower.
Returns:
tf.Operation: the training op
a list, containing the return values of `tower_fn` on each tower.
"""
raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
if self.ps_device == 'gpu':
......@@ -158,7 +154,21 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
return DataParallelBuilder.build_on_towers(self.towers, tower_fn, devices)
def build(self, grad_list, get_opt_fn):
"""
Reduce the gradients, apply them with the optimizer,
and set self.grads to a list of (g, v), containing the averaged gradients.
Args:
grad_list ([[(grad, var), ...], ...]): #GPU lists to be reduced. Each contains the gradients computed on one GPU.
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
tf.Operation: the training op
"""
assert len(grad_list) == len(self.towers)
DataParallelBuilder._check_grad_list(grad_list)
# debug tower performance (without update):
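For context on what this `build()` now does with `grad_list`: in parameter-server mode the reduction amounts to averaging each variable's gradient over the #GPU copies. The snippet below is only a hedged sketch of that step, using plain TF 1.x ops rather than tensorpack's actual helpers, and assuming every tower returns its (grad, var) pairs in the same variable order:

```python
import tensorflow as tf

def average_grads(grad_list):
    """grad_list: #GPU lists of (grad, var) -> one list of (averaged_grad, var)."""
    averaged = []
    # zip(*grad_list) groups the i-th (grad, var) pair of every GPU together.
    for grads_and_vars in zip(*grad_list):
        grads = [g for g, _ in grads_and_vars]
        var = grads_and_vars[0][1]   # in PS mode the variable is shared across towers
        avg = tf.multiply(tf.add_n(grads), 1.0 / len(grads))
        averaged.append((avg, var))
    return averaged
```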
......@@ -195,13 +205,31 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
assert mode in ['nccl', 'cpu', 'hierarchical'], mode
self._mode = mode
def build(self, get_grad_fn, get_opt_fn):
if self._mode == 'hierarchical' and len(towers) != 8:
logger.warn("mode='hierarchical' requires 8 GPUs. Fallback to mode='nccl'.")
self._mode = 'nccl'
def call_for_each_tower(self, tower_fn):
"""
Call the function `tower_fn` under :class:`TowerContext` for each tower.
Returns:
a list, containing the return values of `tower_fn` on each tower.
"""
Build the graph, and set self.grads to #GPU number of lists of (g, v), containing the
all-reduced gradients on each device.
# if tower_fn returns [(grad, var), ...], this returns #GPU x #VAR x 2
return DataParallelBuilder.build_on_towers(
self.towers,
tower_fn,
# use no variable scope for the first tower
use_vs=[False] + [True] * (len(self.towers) - 1))
def build(self, grad_list, get_opt_fn):
"""
Reduce the gradients, apply them with the optimizer,
and set self.grads to #GPU number of lists of (g, v), containing the all-reduced gradients on each device.
Args:
get_grad_fn (-> [(grad, var)]):
grad_list ([[(grad, var), ...], ...]): #GPU lists to be reduced. Each contains the gradients computed on one GPU.
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
......@@ -213,21 +241,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
It has to be run before the training has started.
And you can optionally run it later to sync non-trainable variables.
"""
assert len(grad_list) == len(self.towers)
raw_devices = ['/gpu:{}'.format(k) for k in self.towers]
# #GPU x #VAR x 2
grad_list = DataParallelBuilder.build_on_towers(
self.towers,
get_grad_fn,
# use no variable scope for the first tower
use_vs=[False] + [True] * (len(self.towers) - 1))
DataParallelBuilder._check_grad_list(grad_list)
if self._mode == 'hierarchical' and len(raw_devices) < 8:
logger.warn("mode='hierarchical' require >= 8 GPUs. Fallback to mode='cpu'.")
self._mode = 'cpu'
dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
dtypes_nccl_supported = [tf.float32, tf.float64]
if get_tf_version_tuple() >= (1, 8):
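For the replicated builder, `grad_list` is #GPU x #VAR x 2 and every GPU must end up with a reduced copy of each gradient for its own variables. The snippet below is only a sketch of the `mode='cpu'` flavor of that reduction (sum and average on `/cpu:0`, then copy back to each device); the real implementation also has the nccl and hierarchical paths guarded by the dtype/version checks above:

```python
import tensorflow as tf

def allreduce_grads_cpu(grad_list, raw_devices):
    """grad_list: #GPU x #VAR x (grad, var) -> same shape, with reduced gradients."""
    nr = len(grad_list)
    new_grad_list = [[] for _ in range(nr)]
    for grads_and_vars in zip(*grad_list):          # iterate over variables
        grads = [g for g, _ in grads_and_vars]
        with tf.device('/cpu:0'):
            avg = tf.multiply(tf.add_n(grads), 1.0 / nr)
        # give every device its own copy, paired with that device's variable
        for i, (_, v) in enumerate(grads_and_vars):
            with tf.device(raw_devices[i]):
                new_grad_list[i].append((tf.identity(avg), v))
    return new_grad_list
```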
......@@ -340,14 +358,12 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
super(AsyncMultiGPUBuilder, self).__init__(towers)
self._scale_gradient = scale_gradient
def build(self, get_grad_fn, get_opt_fn):
def call_for_each_tower(self, tower_fn):
"""
Args:
get_grad_fn (-> [(grad, var)]):
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Call the function `tower_fn` under :class:`TowerContext` for each tower.
Returns:
tf.Operation: the training op
a list, containing the return values of `tower_fn` on each tower.
"""
ps_device = 'cpu' if len(self.towers) >= 4 else 'gpu'
......@@ -358,7 +374,18 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grad_fn, devices)
return DataParallelBuilder.build_on_towers(self.towers, tower_fn, devices)
def build(self, grad_list, get_opt_fn):
"""
Args:
grad_list ([[(grad, var), ...], ...]): #GPU lists to be reduced. Each contains the gradients computed on one GPU.
get_opt_fn (-> tf.train.Optimizer): callable which returns an optimizer
Returns:
tf.Operation: the training op
"""
assert len(grad_list) == len(self.towers)
DataParallelBuilder._check_grad_list(grad_list)
if self._scale_gradient and len(self.towers) > 1:
......
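The `_scale_gradient` branch at the end of the hunk above rescales each tower's gradients before the asynchronous apply, so that running N towers does not multiply the effective update size by N. A minimal sketch of that rescaling, assuming plain TF ops instead of tensorpack's gradient-processor machinery:

```python
import tensorflow as tf

def scale_grads(grads_and_vars, nr_towers):
    """Divide every gradient by the number of towers."""
    factor = 1.0 / nr_towers
    return [(tf.multiply(g, factor), v) for g, v in grads_and_vars]
```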
......@@ -102,8 +102,9 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
grad_list = self._builder.call_for_each_tower(tower_fn)
self.train_op = self._builder.build(grad_list, get_opt_fn)
return []
......@@ -141,8 +142,9 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
grad_list = self._builder.call_for_each_tower(tower_fn)
self.train_op = self._builder.build(grad_list, get_opt_fn)
return []
......@@ -183,8 +185,9 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
if len(self.devices) > 1:
assert isinstance(input, FeedfreeInput), input
self.train_op, post_init_op = self._builder.build(
self._make_get_grad_fn(input, get_cost_fn, get_opt_fn), get_opt_fn)
tower_fn = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)
grad_list = self._builder.call_for_each_tower(tower_fn)
self.train_op, post_init_op = self._builder.build(grad_list, get_opt_fn)
cb = RunOp(
post_init_op,
......
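On the trainer side, the `tower_fn` handed to `call_for_each_tower` is whatever `_make_get_grad_fn(...)` returns: a zero-argument callable that builds the cost for the current tower and returns its (grad, var) list. A hedged sketch of that shape (the input handling is simplified; in the real code the tensors come from the InputSource and the call happens under a TowerContext):

```python
import tensorflow as tf

def make_tower_fn(get_cost_fn, get_opt_fn, input_tensors):
    """Hypothetical equivalent of _make_get_grad_fn(): returns the per-tower callable."""
    def tower_fn():
        cost = get_cost_fn(*input_tensors)      # build the model / loss for this tower
        opt = get_opt_fn()
        # [(grad, var), ...] -- exactly what the builders' build() expects per tower
        return opt.compute_gradients(cost, colocate_gradients_with_ops=True)
    return tower_fn
```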