Commit ce709fa3 authored by Yuxin Wu

fix multigpu training

parent 694e404b
@@ -28,7 +28,6 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
         self.task_index = server_def.task_index
         # TODO XXX ps does't need to build!
         assert self.job_name in ['ps', 'worker'], self.job_name
-        assert tf.test.is_gpu_available()
         logger.info("Distributed training on cluster:\n" + str(server_def.cluster))
         logger.info("My role in the cluster: job={}, task={}".format(self.job_name, self.task_index))
@@ -176,8 +175,8 @@ class DistributedReplicatedBuilder(DataParallelBuilder):
             return grads

         # Ngpu * Nvar * 2
-        grad_list = self.build_on_multi_tower(
-            get_grads,
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers, get_grads,
             devices=self.raw_devices,
             use_vs=[True] * len(self.towers))  # open vs at each tower
         DataParallelBuilder._check_grad_list(grad_list)
@@ -71,7 +71,7 @@ class DataParallelBuilder(GraphBuilder):
         self.towers = towers

     @staticmethod
-    def _check_tf_version(self):
+    def _check_tf_version():
         assert get_tf_version_number() >= 1.1, \
             "TF version {} is too old to run multi GPU training!".format(tf.VERSION)
@@ -84,9 +84,12 @@ class DataParallelBuilder(GraphBuilder):
         nvars = [len(k) for k in grad_list]
         assert len(set(nvars)) == 1, "Number of gradients from each tower is different! " + str(nvars)

-    def build_on_multi_tower(
-            self, func, devices=None, use_vs=None):
+    @staticmethod
+    def build_on_towers(
+            towers, func, devices=None, use_vs=None):
         """
         Run `func` on all towers.

         Args:
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
@@ -98,13 +101,13 @@ class DataParallelBuilder(GraphBuilder):
         ret = []
         if devices is not None:
-            assert len(devices) == len(self.towers)
+            assert len(devices) == len(towers)
         if use_vs is not None:
-            assert len(use_vs) == len(self.towers)
+            assert len(use_vs) == len(towers)

-        tower_names = ['tower{}'.format(idx) for idx in range(len(self.towers))]
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]

-        for idx, t in enumerate(self.towers):
+        for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             usevs = use_vs[idx] if use_vs is not None else False
             with tf.device(device), TowerContext(
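A rough usage sketch of the new static entry point (not part of this commit). The tower ids and the `get_grads` stub below are placeholders; the import path follows the `..graph_builder.training` module referenced later in this diff, and tensorpack with TensorFlow >= 1.1 is assumed to be installed:

from tensorpack.graph_builder.training import DataParallelBuilder

towers = [0, 1]                      # GPU ids to build the graph on (placeholder)

def get_grads():
    # real callers build the model inside the current TowerContext and
    # return a list of (gradient, variable) pairs; an empty list keeps
    # this stub runnable
    return []

grad_list = DataParallelBuilder.build_on_towers(
    towers, get_grads,
    devices=None,                    # defaults to '/gpu:{id}' for each tower
    use_vs=[False, True])            # open a variable scope on every tower but the first
DataParallelBuilder._check_grad_list(grad_list)   # one result per tower, equal lengths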
@@ -177,7 +180,7 @@ class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

         # debug tower performance (without update):
@@ -237,7 +240,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(
+        grad_list = DataParallelBuilder.build_on_towers(
+            self.towers,
             get_grads,  # use no variable scope for the first tower
             use_vs=[False] + [True] * (len(self.towers) - 1))
         grads = SyncMultiGPUReplicatedBuilder._allreduce_grads(grad_list)
@@ -316,10 +320,10 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
             grads = FilterNoneGrad().process(grads)
             return grads

-        grad_list = self.build_on_multi_tower(get_grads, devices)
+        grad_list = DataParallelBuilder.build_on_towers(self.towers, get_grads, devices)
         DataParallelBuilder._check_grad_list(grad_list)

-        if self.scale_gradient and len(self.towers) > 1:
+        if self._scale_gradient and len(self.towers) > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradproc = ScaleGradient(('.*', 1.0 / len(self.towers)), verbose=False)
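A small numeric illustration (not part of the commit) of why the async builder scales each tower's gradient by 1/N, per the comment in the hunk above; the gradient values are made up:

lr, N = 0.1, 4
tower_grads = [2.0, 4.0, 6.0, 8.0]                      # one scalar gradient per tower

sync_step = lr * sum(tower_grads) / N                   # one averaged (sync) update: 0.5
async_scaled = sum(lr * (g / N) for g in tower_grads)   # N scaled async updates, same total: 0.5
async_unscaled = sum(lr * g for g in tower_grads)       # 2.0, i.e. a 4x larger effective learning rate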
@@ -63,7 +63,7 @@ class DistributedTrainerReplicated(Trainer):
         assert config.data is not None and config.model is not None

         self.server = server
-        self._builder = DistributedReplicatedBuilder(self.config.tower, server)
+        self._builder = DistributedReplicatedBuilder(config.tower, server)
         self._input_source = config.data
@@ -11,15 +11,25 @@ from ..graph_builder.input_source import QueueInput, StagingInputWrapper, DummyC
 from ..graph_builder.training import (
     SyncMultiGPUParameterServerBuilder,
     SyncMultiGPUReplicatedBuilder,
-    AsyncMultiGPUBuilder)
+    AsyncMultiGPUBuilder,
+    DataParallelBuilder)
 from .base import Trainer

-__all__ = ['SyncMultiGPUTrainerReplicated',
+__all__ = ['MultiGPUTrainerBase',
+           'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
            'SyncMultiGPUTrainer']


+class MultiGPUTrainerBase(Trainer):
+    """
+    For backward compatibility only
+    """
+    def build_on_multi_tower(towers, func, devices=None, use_vs=None):
+        DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
+
+
 def apply_prefetch_policy(config, gpu_prefetch=True):
     assert (config.data is not None or config.dataflow is not None) and config.model is not None
     if config.data is None and config.dataflow is not None:
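A hedged before/after sketch of the rename that the `MultiGPUTrainerBase` shim above keeps working; the call site is hypothetical and not taken from this diff:

from tensorpack.graph_builder.training import DataParallelBuilder

towers = [0]                 # placeholder GPU id

def get_grads():             # placeholder closure; real code returns (grad, var) pairs
    return []

# legacy spelling, kept alive only through the MultiGPUTrainerBase shim:
#   MultiGPUTrainerBase.build_on_multi_tower(towers, get_grads)
# preferred spelling after this commit:
grad_list = DataParallelBuilder.build_on_towers(towers, get_grads)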