Commit 6bee3c24 authored by Yuxin Wu

Add overridecachingdevice & colocate option

parent 6efe0deb
@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient
 from .utils import (
     LeastLoadedDeviceSetter, override_to_local_variable,
-    allreduce_grads, average_grads_with_colocation)
+    allreduce_grads, average_grads)

 __all__ = ['GraphBuilder',
@@ -109,16 +109,13 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
     It is an equivalent of ``--variable_update=parameter_server`` in
     `tensorflow/benchmarks <https://github.com/tensorflow/benchmarks>`_.
     """
-    def __init__(self, towers, ps_device=None):
+    def __init__(self, towers, ps_device):
         """
         Args:
             towers(list[int]): list of GPU id
             ps_device (str): either 'gpu' or 'cpu', where variables are stored.
-                Setting to 'cpu' might help when #gpu>=4
         """
         super(SyncMultiGPUParameterServerBuilder, self).__init__(towers)
-        if ps_device is None:
-            ps_device = 'cpu' if len(towers) >= 4 else 'gpu'
         assert ps_device in ['cpu', 'gpu']
         self.ps_device = ps_device
@@ -146,7 +143,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
         # self.train_op = tf.group(*ops)
         # return

-        grads = average_grads_with_colocation(grad_list)
+        grads = average_grads(grad_list, colocation=True)
         # grads = grad_list[0]

         opt = get_opt_fn()
...
@@ -8,9 +8,11 @@ import operator

 import tensorflow as tf

-__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
-           'override_to_local_variable', 'allreduce_grads',
-           'average_grads_with_colocation']
+__all__ = ['LeastLoadedDeviceSetter',
+           'OverrideCachingDevice',
+           'OverrideToLocalVariable', 'override_to_local_variable',
+           'allreduce_grads',
+           'average_grads']

 """
@@ -115,13 +117,14 @@ def allreduce_grads(all_grads):
     return ret


-def average_grads_with_colocation(all_grads):
+def average_grads(all_grads, colocation=True):
     """
     Average the gradients, on the device of each variable.

     Args:
         all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
             The variables have to be the same across the K lists.
+        colocation (bool): colocate gradient averaging with the variable

     Returns:
         (N x 2): A list of N (grad, var) tuples, where grad is averaged over K.
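The body change for this function follows in the next hunk. As a quick illustration of the K x N x 2 input layout and the new colocation flag, a minimal sketch assuming TF 1.x graph mode (the import path of average_grads is an assumption):

    # Hedged sketch: two towers (K=2), one shared variable (N=1). Import path assumed.
    import tensorflow as tf
    from tensorpack.graph_builder.utils import average_grads

    x = tf.get_variable('x', shape=[3], initializer=tf.zeros_initializer())
    loss0 = tf.reduce_sum(x * 1.0)      # tower 0's loss
    loss1 = tf.reduce_sum(x * 3.0)      # tower 1's loss
    all_grads = [
        list(zip(tf.gradients(loss0, [x]), [x])),   # tower 0: [(grad, x)]
        list(zip(tf.gradients(loss1, [x]), [x])),   # tower 1: [(grad, x)]
    ]
    # With colocation=True the add_n/multiply ops are placed on x's device;
    # the averaged gradient here is elementwise (1.0 + 3.0) / 2 = 2.0.
    averaged = average_grads(all_grads, colocation=True)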
@@ -137,8 +140,47 @@ def average_grads_with_colocation(all_grads):
         v = grad_and_vars[0][1]
         grads = [g for (g, _) in grad_and_vars]
-        with tf.device(v.device):       # colocate summed grad with var
-            grad = tf.multiply(
-                tf.add_n(grads), 1.0 / nr_tower)
+        if colocation:
+            with tf.device(v.device):       # colocate summed grad with var
+                grad = tf.multiply(
+                    tf.add_n(grads), 1.0 / nr_tower)
+        else:
+            grad = tf.multiply(
+                tf.add_n(grads), 1.0 / nr_tower)
         ret.append((grad, v))
     return ret
+
+
+# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
+class OverrideCachingDevice(object):
+    """Variable getter which caches variables on the least loaded device.
+
+    Variables smaller than a certain threshold are cached on a single specific
+    device, as specified in the constructor. All other variables are load balanced
+    across a pool of devices, by caching each variable on the least loaded device.
+    """
+    def __init__(self, devices, device_for_small_variables,
+                 small_variable_size_threshold):
+        self.devices = devices
+        self.sizes = [0] * len(self.devices)
+        self.device_for_small_variables = device_for_small_variables
+        self.small_variable_size_threshold = small_variable_size_threshold
+
+    def __call__(self, getter, *args, **kwargs):
+        size = tf.TensorShape(kwargs['shape']).num_elements()
+        if size is None:
+            # shape not fully known; cannot balance by size
+            return getter(*args, **kwargs)
+        if not kwargs.get('trainable', True):
+            # leave placement of non-trainable variables untouched
+            return getter(*args, **kwargs)
+
+        if size < self.small_variable_size_threshold:
+            device_name = self.device_for_small_variables
+        else:
+            # pick the device with the fewest cached elements so far
+            device_index, _ = min(enumerate(self.sizes), key=operator.itemgetter(1))
+            device_name = self.devices[device_index]
+            self.sizes[device_index] += size
+
+        kwargs['caching_device'] = device_name
+        var = getter(*args, **kwargs)
+        return var
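OverrideCachingDevice follows the custom variable-getter protocol, so the intended use is as a custom_getter on a variable scope. A hedged usage sketch (import path, device strings and threshold are illustrative, not taken from this commit):

    # Hedged sketch of plugging the getter into a TF 1.x variable scope.
    import tensorflow as tf
    from tensorpack.graph_builder.utils import OverrideCachingDevice  # path assumed

    caching_device = OverrideCachingDevice(
        devices=['/gpu:0', '/gpu:1'],           # pool for large variables, least loaded wins
        device_for_small_variables='/cpu:0',    # everything under the threshold is cached here
        small_variable_size_threshold=1024)     # threshold in number of elements

    with tf.variable_scope('tower0', custom_getter=caching_device):
        bias = tf.get_variable('bias', shape=[64])            # 64 < 1024 -> cached on /cpu:0
        weights = tf.get_variable('W', shape=[1024, 1024])    # load-balanced across the GPU pool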