Commit 6bee3c24 authored by Yuxin Wu

Add OverrideCachingDevice & colocation option

parent 6efe0deb
@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, average_grads_with_colocation)
+    allreduce_grads, average_grads)
 
 __all__ = ['GraphBuilder',
@@ -109,16 +109,13 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     It is an equivalent of ``--variable_update=parameter_server`` in
     `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
     """
-    def __init__(self, towers, ps_device=None):
+    def __init__(self, towers, ps_device):
         """
         Args:
             towers(list[int]): list of GPU id
             ps_device (str): either 'gpu' or 'cpu', where variables are stored.
                 Setting to 'cpu' might help when #gpu>=4
         """
         super(SyncMultiGPUParameterServerBuilder, self).__init__(towers)
-        if ps_device is None:
-            ps_device = 'cpu' if len(towers) >= 4 else 'gpu'
         assert ps_device in ['cpu', 'gpu']
         self.ps_device = ps_device
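With the default gone, every call site must now choose `ps_device` explicitly. A minimal sketch of what that looks like (tower ids are illustrative; the "'cpu' at 4+ GPUs" rule of thumb is just the removed default, now left to the caller):

    # hypothetical caller: store variables on CPU for a 4-GPU job
    builder = SyncMultiGPUParameterServerBuilder(towers=[0, 1, 2, 3], ps_device='cpu')
    # with 1-2 GPUs, ps_device='gpu' is usually the better choice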
@@ -146,7 +143,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         # self.train_op = tf.group(*ops)
         # return
 
-        grads = average_grads_with_colocation(grad_list)
+        grads = average_grads(grad_list, colocation=True)
         # grads = grad_list[0]
 
         opt = get_opt_fn()
@@ -8,9 +8,11 @@ import operator
 import tensorflow as tf
 
-__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
-           'override_to_local_variable', 'allreduce_grads',
-           'average_grads_with_colocation']
+__all__ = ['LeastLoadedDeviceSetter',
+           'OverrideCachingDevice',
+           'OverrideToLocalVariable', 'override_to_local_variable',
+           'allreduce_grads',
+           'average_grads']
 
 """
@@ -115,13 +117,14 @@ def allreduce_grads(all_grads):
     return ret
 
 
-def average_grads_with_colocation(all_grads):
+def average_grads(all_grads, colocation=True):
     """
     Average the gradients, on the device of each variable.
 
     Args:
         all_grads (K x N x 2): A list of K lists. Each of the lists is a list of N (grad, var) tuples.
             The variables have to be the same across the K lists.
+        colocation (bool): colocate gradient averaging with the variable.
     Returns:
         (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
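For reference, a small runnable sketch of the renamed function's contract, assuming TF 1.x and that `average_grads` is imported from this module (the gradient tensors are made up):

    import tensorflow as tf

    va = tf.get_variable('a', shape=[2])   # variables shared across K=2 towers
    vb = tf.get_variable('b', shape=[3])
    all_grads = [
        [(tf.ones([2]), va), (tf.ones([3]), vb)],          # tower 0
        [(3 * tf.ones([2]), va), (3 * tf.ones([3]), vb)],  # tower 1
    ]
    avg = average_grads(all_grads, colocation=True)
    # avg is [(grad_a, va), (grad_b, vb)]; grad_a evaluates to [2., 2.],
    # and each averaging op is placed on its variable's device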
@@ -137,8 +140,47 @@ def average_grads_with_colocation(all_grads):
         v = grad_and_vars[0][1]
         grads = [g for (g, _) in grad_and_vars]
-        with tf.device(v.device):       # colocate summed grad with var
-            grad = tf.multiply(
-                tf.add_n(grads), 1.0 / nr_tower)
-            ret.append((grad, v))
+        if colocation:
+            with tf.device(v.device):   # colocate summed grad with var
+                grad = tf.multiply(
+                    tf.add_n(grads), 1.0 / nr_tower)
+        else:
+            grad = tf.multiply(
+                tf.add_n(grads), 1.0 / nr_tower)
+        ret.append((grad, v))
     return ret
+# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
+class OverrideCachingDevice(object):
+    """Variable getter which caches variables on the least loaded device.
+
+    Variables smaller than a certain threshold are cached on a single specific
+    device, as specified in the constructor. All other variables are load balanced
+    across a pool of devices, by caching each variable on the least loaded device.
+    """
+    def __init__(self, devices, device_for_small_variables,
+                 small_variable_size_threshold):
+        self.devices = devices
+        self.sizes = [0] * len(self.devices)
+        self.device_for_small_variables = device_for_small_variables
+        self.small_variable_size_threshold = small_variable_size_threshold
+
+    def __call__(self, getter, *args, **kwargs):
+        size = tf.TensorShape(kwargs['shape']).num_elements()
+        if size is None:
+            # shape is unknown; let the default getter place the variable
+            return getter(*args, **kwargs)
+        if not kwargs.get('trainable', True):
+            return getter(*args, **kwargs)
+        if size < self.small_variable_size_threshold:
+            device_name = self.device_for_small_variables
+        else:
+            # cache on the device with the fewest elements assigned so far
+            device_index, _ = min(enumerate(self.sizes), key=operator.itemgetter(1))
+            device_name = self.devices[device_index]
+            self.sizes[device_index] += size
+
+        kwargs['caching_device'] = device_name
+        var = getter(*args, **kwargs)
+        return var
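The class is a drop-in `custom_getter`. A minimal TF 1.x sketch of how it can be wired up, mirroring its use in tensorflow/benchmarks (device names and the threshold are illustrative; this is not necessarily how tensorpack hooks it in):

    import tensorflow as tf

    getter = OverrideCachingDevice(
        devices=['/gpu:0', '/gpu:1'],
        device_for_small_variables='/cpu:0',
        small_variable_size_threshold=1024)

    with tf.variable_scope('tower0', custom_getter=getter):
        w = tf.get_variable('w', shape=[4096, 4096])  # large: cached on the least loaded GPU
        b = tf.get_variable('b', shape=[16])          # small: cached on /cpu:0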