Commit 1f5c764d authored by Yuxin Wu

when ps_device=cpu, apply gradients on cpu

parent 3f1e9a14
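For readers skimming the diff below: a minimal standalone sketch (assumptions: TF 1.x graph mode, a toy variable and loss that are not taken from tensorpack) of the device placement this commit introduces. When ps_device is 'cpu', the apply_gradients op is pinned to /cpu:0 so the variable updates run on the same device the variables live on.

```python
# Illustrative sketch only, not the tensorpack trainer code itself.
import tensorflow as tf  # TF 1.x graph-mode API, as used by tensorpack here

ps_device = 'cpu'  # hypothetical trainer argument mirroring setup_graph's ps_device
x = tf.get_variable('x', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(x - 1.0)
opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)  # list of (gradient, variable) pairs

if ps_device == 'cpu':
    # Pin the update op next to the variables, which live on the CPU.
    with tf.device('/cpu:0'):
        train_op = opt.apply_gradients(grads, name='train_op')
else:
    train_op = opt.apply_gradients(grads, name='train_op')
```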
@@ -32,6 +32,7 @@ class TowerContext(object):
         self._index = int(index)
         if use_vs:
             self._vs_name = self._name
+            assert len(self._name)
         else:
             self._vs_name = ''
@@ -157,13 +157,13 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
         return new_tower_grads

     @staticmethod
-    def setup_graph(model, input, ps_device, tower):
+    def setup_graph(model, input, ps_device, towers):
         """
         Args:
             model (ModelDesc):
             input (InputSource):
             ps_device (str):
-            tower (list[int]):
+            towers (list[int]):

         Returns:
             tf.Operation: the training op
......@@ -172,7 +172,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
"""
callbacks = input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
raw_devices = ['/gpu:{}'.format(k) for k in towers]
if ps_device == 'gpu':
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else:
@@ -180,7 +180,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]

         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            tower,
+            towers,
             lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
             devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)
@@ -193,7 +193,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
         grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
         # grads = grad_list[0]

-        train_op = model.get_optimizer().apply_gradients(grads, name='train_op')
+        opt = model.get_optimizer()
+        if ps_device == 'cpu':
+            with tf.device('/cpu:0'):
+                train_op = opt.apply_gradients(grads, name='train_op')
+        else:
+            train_op = opt.apply_gradients(grads, name='train_op')
         return train_op, callbacks

     def _setup(self):
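The ps_device == 'cpu' branch above relies on tf.train.replica_device_setter, which returns a device function that routes variable ops to the parameter-server device and everything else to the worker device. A minimal sketch of that behaviour follows; the variable, op, and two-GPU device list are illustrative and not taken from the commit.

```python
import tensorflow as tf  # TF 1.x

raw_devices = ['/gpu:0', '/gpu:1']  # one entry per training tower
devices = [tf.train.replica_device_setter(
    worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]

with tf.device(devices[0]):
    w = tf.get_variable('w', shape=[10])   # variable op -> /cpu:0
    y = tf.reduce_sum(w * 2.0)             # compute op  -> /gpu:0
```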