Commit 26b4ea44 authored by Yuxin Wu

expose setup_graph in all multigpu trainers; TowerContext(use_vs=True) instead of passing a vs_name

parent 06bb5142
......@@ -14,23 +14,26 @@ _CurrentTowerContext = None
class TowerContext(object):
""" A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
def __init__(self, tower_name, is_training=None, index=0, use_vs=False):
"""
Args:
tower_name (str): The name scope of the tower.
is_training (bool): if None, automatically determine from tower_name.
index (int): index of this tower, only used in training.
vs_name (str): Open a variable scope with this name, if given.
use_vs (bool): Open a variable scope named after the tower.
"""
self._name = tower_name
self._is_training = bool(is_training)
if not self._is_training:
assert index == 0 and vs_name == '', \
"vs_name and index are only used in prediction!"
assert index == 0 and not use_vs, \
"use_vs and index are only used in training!"
self._index = int(index)
self._vs_name = str(vs_name)
if use_vs:
self._vs_name = self._name
else:
self._vs_name = ''
if self.has_own_variables:
assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)
......@@ -96,8 +99,8 @@ class TowerContext(object):
self._ctxs.append(tf.name_scope(self._name))
else:
if self.has_own_variables:
if len(self.vs_name):
self._ctxs.append(tf.variable_scope(self.vs_name))
if len(self._vs_name):
self._ctxs.append(tf.variable_scope(self._vs_name))
else:
self._ctxs.append(tf.name_scope(self._name))
else:
......
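As a minimal sketch of the new interface (assuming the usual import path `tensorpack.tfutils.tower`, which this commit does not change), a caller now asks for a tower-named variable scope with a boolean instead of spelling out the scope name:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext

# Before: TowerContext('tower0', is_training=True, index=0, vs_name='tower0')
# After: use_vs=True derives the variable-scope name from the tower name itself.
with TowerContext('tower0', is_training=True, index=0, use_vs=True):
    # Variables created here are expected to live under the 'tower0' scope,
    # since _vs_name is set to the tower name when use_vs=True.
    w = tf.get_variable('w', shape=[10])
```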
......@@ -195,7 +195,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self.model, self._input_source),
devices=self.raw_devices,
var_strategy='replicated',
vs_names=None) # use the default vs names
use_vs=[True] * self.config.nr_tower)  # open vs at each tower
MultiGPUTrainerBase._check_grad_list(grad_list)
avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
......
......@@ -49,14 +49,14 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
def build_on_multi_tower(
towers, func,
devices=None, var_strategy='shared',
vs_names=None):
use_vs=None):
"""
Args:
towers: list of gpu relative ids
func: a lambda to be called inside each tower
devices: a list of devices to be used. By default will use GPUs in ``towers``.
var_strategy (str): 'shared' or 'replicated'
vs_names (list[str]): list of variable scope names to use.
use_vs (list[bool]): list of use_vs values to be passed to each TowerContext
Returns:
List of outputs of ``func``, evaluated on each tower.
......@@ -74,17 +74,11 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
if var_strategy == 'replicated': # TODO ugly
logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
# fix all Nones. TODO ugly
if vs_names is not None:
assert len(vs_names) == len(towers)
for idx, name in enumerate(vs_names):
if name is None:
vs_names[idx] = tower_names[idx]
else:
vs_names = tower_names
else:
assert vs_names is None
vs_names = [''] * len(towers)
assert use_vs is None
if use_vs is None:
use_vs = [False] * len(towers)
assert len(use_vs) == len(towers)
for idx, t in enumerate(towers):
device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
......@@ -92,7 +86,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
tower_names[idx],
is_training=True,
index=idx,
vs_name=vs_names[idx]):
use_vs=use_vs[idx]):
if idx == t:
logger.info("Building graph for training tower {}...".format(idx))
else:
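A hedged sketch of the new call site, showing that `use_vs` is a list of booleans, one per tower, defaulting to all False; `model` and `input` below are placeholders for an already constructed ModelDesc and InputSource:

```python
# Placeholder objects: assume `model` is a ModelDesc and `input` an InputSource
# that has already been set up, as in the trainers further below.
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
    towers=[0, 1],
    func=lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
    var_strategy='replicated',
    # tower 0 builds in the root scope; tower 1 opens its own variable scope
    use_vs=[False, True])
```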
......@@ -192,20 +186,32 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
new_tower_grads.append((grad, v))
return new_tower_grads
def _setup(self):
super(SyncMultiGPUTrainerParameterServer, self)._setup()
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
if self._ps_device == 'gpu':
@staticmethod
def setup_graph(model, input, ps_device, tower):
"""
Args:
model (ModelDesc):
input (InputSource):
ps_device (str):
tower (list[int]):
Returns:
tf.Operation: the training op
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
if ps_device == 'gpu':
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
else:
devices = [tf.train.replica_device_setter(
worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source), devices)
tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
devices)
MultiGPUTrainerBase._check_grad_list(grad_list)
# debug tower performance (without update):
......@@ -213,11 +219,16 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
# self.train_op = tf.group(*ops)
# return
grads = self._average_grads(grad_list)
grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
# grads = grad_list[0]
self.train_op = self.model.get_optimizer().apply_gradients(
grads, name='train_op')
train_op = model.get_optimizer().apply_gradients(grads, name='train_op')
return train_op, input.get_callbacks()
def _setup(self):
self.train_op, cbs = SyncMultiGPUTrainerParameterServer.setup_graph(
self.model, self._input_source, self._ps_device, self.config.tower)
self.config.callbacks.extend(cbs)
def SyncMultiGPUTrainer(config):
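Because `setup_graph` is now a `@staticmethod`, the graph-building logic can be reused without going through `_setup`; a sketch with placeholder `model` / `input_source` objects:

```python
# `model` (ModelDesc) and `input_source` (InputSource) are placeholders here;
# setup_graph calls input_source.setup() itself before building the towers.
train_op, callbacks = SyncMultiGPUTrainerParameterServer.setup_graph(
    model, input_source, ps_device='gpu', tower=[0, 1])
# train_op applies the averaged gradients; the callbacks come from the
# InputSource and must still be registered with whatever trainer is used.
```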
......@@ -266,31 +277,47 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
# NVar * NGPU * 2
return new_tower_grads
def _setup(self):
super(SyncMultiGPUTrainerReplicated, self)._setup()
raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
@staticmethod
def setup_graph(model, input, tower):
"""
Args:
model (ModelDesc):
input (InputSource):
tower (list[int]):
Returns:
tf.Operation: the training op
[Callback]: the callbacks to be added
"""
input.setup(model.get_inputs_desc())
raw_devices = ['/gpu:{}'.format(k) for k in tower]
grad_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(
self.model, self._input_source),
tower,
lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
var_strategy='replicated',
# use no variable scope for the first tower
vs_names=[''] + [None] * (self.config.nr_tower - 1))
grads = self._allreduce_grads(grad_list)
use_vs=[False] + [True] * (len(tower) - 1))
grads = SyncMultiGPUTrainerReplicated._allreduce_grads(grad_list)
train_ops = []
opt = self.model.get_optimizer()
for idx in range(self.config.nr_tower):
opt = model.get_optimizer()
for idx in range(len(tower)):
with tf.device(raw_devices[idx]):
grad_and_vars = [x[idx] for x in grads]
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(idx)))
self.train_op = tf.group(*train_ops, name='train_op')
self.register_callback(RunOp(
train_op = tf.group(*train_ops, name='train_op')
cb = RunOp(
SyncMultiGPUTrainerReplicated.get_post_init_ops,
run_before=True, run_as_trigger=True, verbose=True))
run_before=True, run_as_trigger=True, verbose=True)
return train_op, input.get_callbacks() + [cb]
def _setup(self):
self.train_op, cbs = SyncMultiGPUTrainerReplicated.setup_graph(
self.model, self._input_source, self.config.tower)
self.config.callbacks.extend(cbs)
# Adapted from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
@staticmethod
......
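The replicated trainer follows the same pattern; the difference in the return value is the extra `RunOp` callback that re-runs `get_post_init_ops` to keep the per-tower variable copies in sync (again with placeholder `model` / `input_source` objects):

```python
train_op, callbacks = SyncMultiGPUTrainerReplicated.setup_graph(
    model, input_source, tower=[0, 1])
# callbacks == input_source.get_callbacks() + [RunOp(get_post_init_ops, ...)],
# so extending config.callbacks with it preserves the post-init sync behavior.
```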