Commit 1f5c764d authored by Yuxin Wu's avatar Yuxin Wu

when ps_device=cpu, apply gradients on cpu

parent 3f1e9a14
...@@ -32,6 +32,7 @@ class TowerContext(object): ...@@ -32,6 +32,7 @@ class TowerContext(object):
self._index = int(index) self._index = int(index)
if use_vs: if use_vs:
self._vs_name = self._name self._vs_name = self._name
assert len(self._name)
else: else:
self._vs_name = '' self._vs_name = ''
......
...@@ -157,13 +157,13 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase): ...@@ -157,13 +157,13 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
return new_tower_grads return new_tower_grads
@staticmethod @staticmethod
def setup_graph(model, input, ps_device, tower): def setup_graph(model, input, ps_device, towers):
""" """
Args: Args:
model (ModelDesc): model (ModelDesc):
input (InputSource): input (InputSource):
ps_device (str): ps_device (str):
tower (list[int]): towers (list[int]):
Returns: Returns:
tf.Operation: the training op tf.Operation: the training op
...@@ -172,7 +172,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase): ...@@ -172,7 +172,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
""" """
callbacks = input.setup(model.get_inputs_desc()) callbacks = input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower] raw_devices = ['/gpu:{}'.format(k) for k in towers]
if ps_device == 'gpu': if ps_device == 'gpu':
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else: else:
...@@ -180,7 +180,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase): ...@@ -180,7 +180,7 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices] worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower( grad_list = MultiGPUTrainerBase.build_on_multi_tower(
tower, towers,
lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input), lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
devices) devices)
MultiGPUTrainerBase._check_grad_list(grad_list) MultiGPUTrainerBase._check_grad_list(grad_list)
...@@ -193,7 +193,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase): ...@@ -193,7 +193,12 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list) grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
# grads = grad_list[0] # grads = grad_list[0]
train_op = model.get_optimizer().apply_gradients(grads, name='train_op') opt = model.get_optimizer()
if ps_device == 'cpu':
with tf.device('/cpu:0'):
train_op = opt.apply_gradients(grads, name='train_op')
else:
train_op = opt.apply_gradients(grads, name='train_op')
return train_op, callbacks return train_op, callbacks
def _setup(self): def _setup(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment