Commit d3e0a688 authored by Yuxin Wu's avatar Yuxin Wu

rename `average_grads

parent 363777e4
...@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient ...@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient
from .utils import ( from .utils import (
LeastLoadedDeviceSetter, override_to_local_variable, LeastLoadedDeviceSetter, override_to_local_variable,
allreduce_grads, average_grads) allreduce_grads, average_grads_with_colocation)
__all__ = ['GraphBuilder', __all__ = ['GraphBuilder',
...@@ -145,7 +145,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder): ...@@ -145,7 +145,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
# self.train_op = tf.group(*ops) # self.train_op = tf.group(*ops)
# return # return
grads = average_grads(grad_list) grads = average_grads_with_colocation(grad_list)
# grads = grad_list[0] # grads = grad_list[0]
opt = get_opt_fn() opt = get_opt_fn()
......
...@@ -9,7 +9,8 @@ import tensorflow as tf ...@@ -9,7 +9,8 @@ import tensorflow as tf
__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable', __all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
'override_to_local_variable', 'allreduce_grads', 'average_grads'] 'override_to_local_variable', 'allreduce_grads',
'average_grads_with_colocation']
""" """
...@@ -114,7 +115,7 @@ def allreduce_grads(all_grads): ...@@ -114,7 +115,7 @@ def allreduce_grads(all_grads):
return ret return ret
def average_grads(all_grads): def average_grads_with_colocation(all_grads):
""" """
Average the gradients, on the device of each variable. Average the gradients, on the device of each variable.
......
...@@ -95,7 +95,7 @@ def SyncMultiGPUTrainer(gpus): ...@@ -95,7 +95,7 @@ def SyncMultiGPUTrainer(gpus):
Args: Args:
gpus (list[int]): list of GPU ids. gpus (list[int]): list of GPU ids.
""" """
return SyncMultiGPUTrainerParameterServer(gpus, ps_device='gpu') return SyncMultiGPUTrainerParameterServer(gpus, ps_device='cpu')
class AsyncMultiGPUTrainer(SingleCostTrainer): class AsyncMultiGPUTrainer(SingleCostTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment