Commit d3e0a688 authored by Yuxin Wu's avatar Yuxin Wu

rename `average_grads

parent 363777e4
......@@ -14,7 +14,7 @@ from ..tfutils.gradproc import ScaleGradient
from .utils import (
LeastLoadedDeviceSetter, override_to_local_variable,
allreduce_grads, average_grads)
allreduce_grads, average_grads_with_colocation)
__all__ = ['GraphBuilder',
......@@ -145,7 +145,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
# self.train_op = tf.group(*ops)
# return
grads = average_grads(grad_list)
grads = average_grads_with_colocation(grad_list)
# grads = grad_list[0]
opt = get_opt_fn()
......
......@@ -9,7 +9,8 @@ import tensorflow as tf
__all__ = ['LeastLoadedDeviceSetter', 'OverrideToLocalVariable',
'override_to_local_variable', 'allreduce_grads', 'average_grads']
'override_to_local_variable', 'allreduce_grads',
'average_grads_with_colocation']
"""
......@@ -114,7 +115,7 @@ def allreduce_grads(all_grads):
return ret
def average_grads(all_grads):
def average_grads_with_colocation(all_grads):
"""
Average the gradients, on the device of each variable.
......
......@@ -95,7 +95,7 @@ def SyncMultiGPUTrainer(gpus):
Args:
gpus (list[int]): list of GPU ids.
"""
return SyncMultiGPUTrainerParameterServer(gpus, ps_device='gpu')
return SyncMultiGPUTrainerParameterServer(gpus, ps_device='cpu')
class AsyncMultiGPUTrainer(SingleCostTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment