Commit ce709fa3 authored by Yuxin Wu

fix multigpu training

parent 694e404b
@@ -28,7 +28,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         self.task_index = server_def.task_index
         # TODO XXX ps does't need to build!
         assert self.job_name in ['ps', 'worker'], self.job_name
-        assert tf.test.is_gpu_available()

         logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
         logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
@@ -176,8 +175,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
             return grads

         # Ngpu * Nvar * 2
-        grad_list = self.build_on_multi_tower(
-            get_grads,
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers, get_grads,
             devices=self.raw_devices,
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)
......
@@ -71,7 +71,7 @@ class DataParallelBuilder(GraphBuilder):
         self.towers = towers

     @staticmethod
-    def _check_tf_version(self):
+    def _check_tf_version():
         assert get_tf_version_number() >= 1.1, \
             "TF version {} is too old to run multi GPU training!".format(tf.VERSION)
@@ -84,9 +84,12 @@ class DataParallelBuilder(GraphBuilder):
         nvars = [len(k) for k in grad_list]
         assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)

-    def build_on_multi_tower(
-            self, func, devices=None, use_vs=None):
+    @staticmethod
+    def build_on_towers(
+            towers, func, devices=None, use_vs=None):
         """
+        Run `func` on all towers.
+
         Args:
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
@@ -98,13 +101,13 @@ class DataParallelBuilder(GraphBuilder):
         ret = []
         if devices is not None:
-            assert len(devices) == len(self.towers)
+            assert len(devices) == len(towers)
         if use_vs is not None:
-            assert len(use_vs) == len(self.towers)
+            assert len(use_vs) == len(towers)

-        tower_names = ['tower{}'.format(idx) for idx in range(len(self.towers))]
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]

-        for idx, t in enumerate(self.towers):
+        for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
             with tf.device(device), TowerContext(
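As the two hunks above show, `build_on_multi_tower` becomes the `@staticmethod` `build_on_towers`, which receives the tower list explicitly instead of reading `self.towers`. The following is a plain-Python schematic of the loop it runs, with the device placement and `TowerContext` handling reduced to comments; it is a sketch, not the library code:

    def build_on_towers_sketch(towers, func, devices=None, use_vs=None):
        # Schematic of DataParallelBuilder.build_on_towers after this commit.
        if devices is not None:
            assert len(devices) == len(towers)
        if use_vs is not None:
            assert len(use_vs) == len(towers)
        ret = []
        for idx, t in enumerate(towers):
            device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
            usevs = use_vs[idx] if use_vs is not None else False
            # The real method enters tf.device(device) and a TowerContext here,
            # opening a new variable scope for the tower when usevs is True.
            ret.append(func())
        return ret

    # Example: map a dummy per-tower function over two GPUs.
    print(build_on_towers_sketch([0, 1], lambda: 'grads', use_vs=[False, True]))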
@@ -177,7 +180,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

         # debug tower performance (without update):
@@ -237,7 +240,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers,
             get_grads,  # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(self.towers) - 1))
         grads = SyncMultiGPUReplicatedBuilder._allreduce_grads(grad_list)
@@ -316,10 +320,10 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

-        if self.scale_gradient and len(self.towers) > 1:
+        if self._scale_gradient and len(self.towers) > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradproc = ScaleGradient(('.*', 1.0 / len(self.towers)), verbose=False)
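The unchanged `ScaleGradient(('.*', 1.0 / len(self.towers)), ...)` line above is what the renamed `self._scale_gradient` flag guards: with N towers the async trainer applies N separate updates per step, so each gradient is scaled by 1/N to keep the effective learning rate in line with the sync builders, which average the gradients. A tiny numeric check of that equivalence (hypothetical gradient values, plain Python):

    n_towers = 4
    per_tower_grads = [0.8, 1.0, 1.2, 1.0]                   # hypothetical gradients of one variable

    sync_update = sum(per_tower_grads) / n_towers            # sync builders apply the averaged gradient once
    async_updates = [g / n_towers for g in per_tower_grads]  # async applies each scaled gradient separately

    # Over one round of async updates the total step matches the sync step.
    assert abs(sum(async_updates) - sync_update) < 1e-9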
......
@@ -63,7 +63,7 @@ class DistributedTrainerReplicated(Trainer):
         assert config.data is not None and config.model is not None

         self.server = server
-        self._builder = DistributedReplicatedBuilder(self.config.tower, server)
+        self._builder = DistributedReplicatedBuilder(config.tower, server)

         self._input_source = config.data
......
@@ -11,15 +11,25 @@ from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyC
 from ..graph_builder.training import (
     SyncMultiGPUParameterServerBuilder,
     SyncMultiGPUReplicatedBuilder,
-    AsyncMultiGPUBuilder)
+    AsyncMultiGPUBuilder,
+    DataParallelBuilder)
 from .base import Trainer

-__all__ = ['SyncMultiGPUTrainerReplicated',
+__all__ = ['MultiGPUTrainerBase',
+           'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
            'SyncMultiGPUTrainer']


+class MultiGPUTrainerBase(Trainer):
+    """
+    For backward compatibility only
+    """
+    def build_on_multi_tower(towers, func, devices=None, use_vs=None):
+        DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
+
+
 def apply_prefetch_policy(config, gpu_prefetch=True):
     assert (config.data is not None or config.dataflow is not None) and config.model is not None
     if config.data is None and config.dataflow is not None:
......
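The `MultiGPUTrainerBase` added above is kept only as a backward-compatibility shim: its `build_on_multi_tower` forwards to the new `DataParallelBuilder.build_on_towers`. Note that, as committed, the shim does not return the forwarded result, so a legacy caller that used the returned gradient list would receive `None`. A self-contained sketch of the forwarding pattern (stub classes, not the tensorpack code; unlike the shim in the hunk, the stub also returns the result for illustration):

    class DataParallelBuilderStub:
        @staticmethod
        def build_on_towers(towers, func, devices=None, use_vs=None):
            return [func() for _ in towers]        # stand-in for the real per-tower graph build

    class MultiGPUTrainerBaseStub:
        @staticmethod
        def build_on_multi_tower(towers, func, devices=None, use_vs=None):
            # Forward old-style calls to the new static method (and return the result).
            return DataParallelBuilderStub.build_on_towers(towers, func, devices, use_vs)

    # Legacy-style call keeps working through the shim:
    print(MultiGPUTrainerBaseStub.build_on_multi_tower([0, 1], lambda: 'grads'))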