Commit 18064a54 authored by Yuxin Wu

Match sync training speed with tf/benchmarks (fix #254)

parent f26b9b59
......@@ -126,8 +126,12 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
add_model_variable(moving_var)
# seems faster than delayed update, but might behave otherwise in distributed settings.
with tf.control_dependencies([update_op1, update_op2]):
return tf.identity(xn, name='output')
# TODO add an option, and maybe enable it for replica mode?
# with tf.control_dependencies([update_op1, update_op2]):
# return tf.identity(xn, name='output')
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op1)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op2)
return xn
def reshape_for_bn(param, ndims, chan, data_format):
......
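The hunk above moves the BN moving-average updates out of a hard control dependency on the layer output and into the tf.GraphKeys.UPDATE_OPS collection, so they only run when the training loop explicitly asks for them. A minimal sketch of the usual pattern for executing collected update ops (the variable and cost below are placeholders, not part of this commit):

import tensorflow as tf

w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())
cost = tf.reduce_sum(tf.square(w))
# UPDATE_OPS now holds update_op1/update_op2 added by the BN layer above
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(cost)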
......@@ -28,7 +28,8 @@ def describe_model():
table = tabulate(data, headers=['name', 'shape', 'dim'])
size_mb = total * 4 / 1024.0**2
summary_msg = colored(
"\nTotal #param={} ({:.02f} MB assuming all float32)".format(total, size_mb), 'cyan')
"\nTotal #vars={}, #param={} ({:.02f} MB assuming all float32)".format(
len(data), total, size_mb), 'cyan')
logger.info(colored("Model Parameters: \n", 'cyan') + table + summary_msg)
......
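For reference, a small self-contained sketch of how the two reported numbers relate (assuming all variables are float32, as the message says):

import tensorflow as tf

tvars = tf.trainable_variables()
total = sum(v.get_shape().num_elements() for v in tvars)     # #param
size_mb = total * 4 / 1024.0 ** 2                             # 4 bytes per float32
print("#vars={}, #param={} ({:.02f} MB)".format(len(tvars), total, size_mb))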
......@@ -19,14 +19,12 @@ class TowerContext(object):
"""
Args:
tower_name (str): 'tower0', 'towerp0', or ''
device (str): the device to use. Defaults to either cpu0 or gpu0.
device (str or device function): the device to use. Defaults to either cpu0 or gpu0.
is_training (bool): if None, automatically determine from tower_name.
"""
self._name = tower_name
if device is None:
device = '/gpu:0' if tf.test.is_gpu_available() else '/cpu:0'
assert self.index == int(device[-1]), \
"Tower name {} and device {} mismatch!".format(self._name, device)
self._device = device
if is_training is None:
......
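The docstring change reflects that device may now be a device function (a callable mapping an op to a device string, as accepted by tf.device), not only a plain string. A sketch of both forms, assuming the constructor and keyword arguments shown in the docstring above:

import tensorflow as tf

# plain string device
with TowerContext('tower0', device='/gpu:0', is_training=True):
    a = tf.ones([2, 2])

# device function: ops go to the worker GPU, variables to the CPU
device_fn = tf.train.replica_device_setter(
    worker_device='/gpu:0', ps_device='/cpu:0', ps_tasks=1)
with TowerContext('tower0', device=device_fn, is_training=True):
    w = tf.get_variable('w', shape=[2, 2])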
......@@ -5,6 +5,7 @@
import tensorflow as tf
import itertools
import operator
import re
from six.moves import zip, range
......@@ -25,25 +26,25 @@ __all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
class MultiGPUTrainer(Trainer):
""" Base class for multi-gpu training"""
@staticmethod
def multi_tower_grads(towers, get_tower_grad_func):
""" ret[i] is a lists of (grad,var) tuple for tower i"""
return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_grad_func)
@staticmethod
def multi_tower_costs(towers, get_tower_cost_func):
return MultiGPUTrainer._build_on_multi_tower(towers, get_tower_cost_func)
@staticmethod
def _build_on_multi_tower(towers, func):
def build_on_multi_tower(towers, func, devices=None):
"""
Args:
towers: list of gpu relative ids
func: a lambda to be called inside each tower
devices: a list of devices to be used. By default will use GPUs in towers.
"""
logger.info("Training a model of {} tower".format(len(towers)))
ret = []
global_scope = tf.get_variable_scope()
if devices is not None:
assert len(devices) == len(towers)
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
with tf.variable_scope(global_scope, reuse=idx > 0), \
TowerContext(
'tower{}'.format(idx),
device='/gpu:{}'.format(t),
device=device,
is_training=True):
logger.info("Building graph for training tower {}...".format(idx))
......@@ -56,17 +57,46 @@ class MultiGPUTrainer(Trainer):
return ret
class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer):
# Copied from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
class ParamServerDeviceSetter(object):
"""Helper class to assign variables on the least loaded ps-device."""
def __init__(self, worker_device, ps_devices):
"""
Args:
worker_device: the device to use for compute ops.
ps_devices: a list of devices to use for Variable ops. Each variable is
assigned to the least loaded device.
"""
self.ps_devices = ps_devices
self.worker_device = worker_device
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
if op.device:
return op.device
if op.type not in ['Variable', 'VariableV2']:
return self.worker_device
device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements()
self.ps_sizes[device_index] += var_size
return device_name
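ParamServerDeviceSetter greedily assigns each new variable to whichever ps-device currently holds the fewest elements, while every other op stays on the worker device. A minimal usage sketch (the names below are illustrative); tf.device accepts such a callable directly:

import tensorflow as tf

raw_devices = ['/gpu:0', '/gpu:1']
setter = ParamServerDeviceSetter(worker_device='/gpu:0', ps_devices=raw_devices)
with tf.device(setter):
    w = tf.get_variable('w', shape=[1024, 1024])   # placed on the least loaded GPU
    y = tf.matmul(tf.ones([1, 1024]), w)           # compute op stays on /gpu:0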
class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrainer):
"""
A multi-tower multi-GPU trainer which synchronizes the gradients computed
from each tower and averages them.
from each tower, averages them, and updates the variables stored on the PS.
"""
def __init__(self, config):
def __init__(self, config, ps_device='gpu'):
"""
Args:
config: same as in :class:`QueueInputTrainer`.
ps_device: either 'gpu' or 'cpu'
"""
if config.dataflow is not None:
# use queueinput by default. May need to avoid this in the future (when more input types are available)
......@@ -74,58 +104,77 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
else:
self._input_method = config.data
assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one tower."
if len(config.tower) > 1:
assert tf.test.is_gpu_available()
# doesn't seem to improve on single GPU
# seems to only improve on >1 GPUs
if not isinstance(self._input_method, StagingInputWrapper):
devices = ['/gpu:{}'.format(k) for k in config.tower]
self._input_method = StagingInputWrapper(self._input_method, devices)
super(SyncMultiGPUTrainer, self).__init__(config)
assert ps_device in ['gpu', 'cpu'], ps_device
self._ps_device = ps_device
super(SyncMultiGPUTrainerParameterServer, self).__init__(config)
@staticmethod
def _average_grads(tower_grads):
if len(tower_grads) == 1:
nr_tower = len(tower_grads)
if nr_tower == 1:
return tower_grads[0]
ret = []
new_tower_grads = []
with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads):
# Ngpu * 2
v = grad_and_vars[0][1]
all_grad = [k[0] for k in grad_and_vars]
all_grads = [g for (g, _) in grad_and_vars]
nones = list(set(all_grad))
nones = list(set(all_grads))
if None in nones and len(nones) != 1:
raise RuntimeError("Gradient w.r.t {} is None in some but not all towers!".format(v.name))
elif nones[0] is None:
logger.warn("No Gradient w.r.t {}".format(v.op.name))
continue
try:
grad = tf.add_n(all_grad) / float(len(tower_grads))
with tf.device(v.device): # colocate summed grad with var
grad = tf.multiply(tf.add_n(all_grads), 1.0 / nr_tower)
except:
logger.error("Error while processing gradients of {}".format(v.name))
raise
ret.append((grad, v))
return ret
new_tower_grads.append((grad, v))
return new_tower_grads
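_average_grads transposes the per-tower gradient lists so that each group holds one variable's gradient from every tower, then sums and rescales on the variable's own device. A small self-contained illustration of that transposition (the constants and variables here are made up):

import tensorflow as tf

v0 = tf.Variable(tf.zeros([3]), name='v0')
v1 = tf.Variable(tf.zeros([2]), name='v1')
tower_grads = [
    [(tf.constant([1., 1., 1.]), v0), (tf.constant([2., 2.]), v1)],  # tower 0
    [(tf.constant([3., 3., 3.]), v0), (tf.constant([4., 4.]), v1)],  # tower 1
]
averaged = []
for grad_and_vars in zip(*tower_grads):       # one group per variable
    v = grad_and_vars[0][1]
    grads = [g for g, _ in grad_and_vars]
    with tf.device(v.device):                 # colocate the summed grad with its variable
        averaged.append((tf.add_n(grads) * (1.0 / len(tower_grads)), v))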
def _setup(self):
super(SyncMultiGPUTrainer, self)._setup()
super(SyncMultiGPUTrainerParameterServer, self)._setup()
grad_list = MultiGPUTrainer.multi_tower_grads(
self.config.tower, lambda: self._get_cost_and_grad()[1])
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
if self._ps_device == 'gpu':
devices = [ParamServerDeviceSetter(d, raw_devices) for d in raw_devices]
else:
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = MultiGPUTrainer.build_on_multi_tower(
self.config.tower, lambda: self._get_cost_and_grad()[1], devices)
# debug tower performance (without update):
# ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
# self.train_op = tf.group(*ops)
# return
grads = SyncMultiGPUTrainer._average_grads(grad_list)
grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
# grads = grad_list[0]
self.train_op = self.config.optimizer.apply_gradients(grads, name='min_op')
def SyncMultiGPUTrainer(config):
"""
Alias for ``SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')``,
as this is the most commonly used synchronous multigpu trainer.
"""
return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
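A hypothetical end-to-end use of the alias (the TrainConfig fields, model and dataflow below are illustrative placeholders, not taken from this commit):

config = TrainConfig(model=MyModel(), dataflow=df,
                     optimizer=tf.train.AdamOptimizer(1e-3), tower=[0, 1, 2, 3])
SyncMultiGPUTrainer(config).train()
# equivalent to:
SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train()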
class AsyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer):
"""
......@@ -155,7 +204,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def _setup(self):
super(AsyncMultiGPUTrainer, self)._setup()
grad_list = MultiGPUTrainer.multi_tower_grads(
grad_list = MultiGPUTrainer.build_on_multi_tower(
self.config.tower, lambda: self._get_cost_and_grad()[1])
grad_list = [FilterNoneGrad().process(gv) for gv in grad_list]
if self._scale_gradient and self.config.nr_tower > 1:
......