Commit 26b4ea44 authored by Yuxin Wu

expose setup_graph in all multigpu trainers; TowerContext(use_vs=True) instead of passing a vs_name

parent 06bb5142
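The heart of the interface change is visible at the TowerContext call site. A minimal before/after sketch (the tower name 'tower1' is only an illustrative value):

    # before this commit: the caller spelled out the variable scope name
    ctx = TowerContext('tower1', is_training=True, index=1, vs_name='tower1')

    # after this commit: the caller only opts in to a variable scope,
    # and the tower name itself is reused as the scope name
    ctx = TowerContext('tower1', is_training=True, index=1, use_vs=True)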
@@ -14,23 +14,26 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """
 
-    def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
+    def __init__(self, tower_name, is_training=None, index=0, use_vs=False):
         """
         Args:
             tower_name (str): The name scope of the tower.
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower, only used in training.
-            vs_name (str): Open a variable scope with this name, if given.
+            use_vs (bool): Open a variable scope with this name.
         """
         self._name = tower_name
         self._is_training = bool(is_training)
 
         if not self._is_training:
-            assert index == 0 and vs_name == '', \
-                "vs_name and index are only used in prediction!"
+            assert index == 0 and not use_vs, \
+                "use_vs and index are only used in training!"
 
         self._index = int(index)
-        self._vs_name = str(vs_name)
+        if use_vs:
+            self._vs_name = self._name
+        else:
+            self._vs_name = ''
 
         if self.has_own_variables:
             assert not tf.get_variable_scope().reuse, "reuse=True in tower {}!".format(tower_name)
@@ -96,8 +99,8 @@ class TowerContext(object):
             self._ctxs.append(tf.name_scope(self._name))
         else:
             if self.has_own_variables:
-                if len(self.vs_name):
-                    self._ctxs.append(tf.variable_scope(self.vs_name))
+                if len(self._vs_name):
+                    self._ctxs.append(tf.variable_scope(self._vs_name))
                 else:
                     self._ctxs.append(tf.name_scope(self._name))
             else:
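For reference, a hedged sketch of what the flag does once the context is entered (TF 1.x graph mode; the variable name is illustrative and not part of the commit):

    import tensorflow as tf

    # use_vs=True sets _vs_name to the tower name, so entering the context opens
    # tf.variable_scope('tower1') and variables created inside get that prefix.
    with TowerContext('tower1', is_training=True, index=1, use_vs=True):
        w = tf.get_variable('w', shape=[3])   # created under the 'tower1' scope

    # use_vs=False leaves _vs_name == '', so a tower that owns its variables only
    # opens a name scope and tf.get_variable() names are not prefixed.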
@@ -195,7 +195,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
                 self.model, self._input_source),
             devices=self.raw_devices,
             var_strategy='replicated',
-            vs_names=None)  # use the default vs names
+            use_vs=[True] * self.config.nr_tower)  # open vs at each tower
         MultiGPUTrainerBase._check_grad_list(grad_list)
 
         avg_grads = DistributedReplicatedTrainer._average_grads(grad_list, self.raw_devices)
@@ -49,14 +49,14 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
     def build_on_multi_tower(
             towers, func,
             devices=None, var_strategy='shared',
-            vs_names=None):
+            use_vs=None):
         """
         Args:
             towers: list of gpu relative ids
             func: a lambda to be called inside each tower
             devices: a list of devices to be used. By default will use GPUs in ``towers``.
             var_strategy (str): 'shared' or 'replicated'
-            vs_names (list[str]): list of variable scope names to use.
+            use_vs (list[bool]): list of use_vs values to be passed to TowerContext.
 
         Returns:
             List of outputs of ``func``, evaluated on each tower.
@@ -74,17 +74,11 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
         if var_strategy == 'replicated':     # TODO ugly
             logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
-            # fix all Nones. TODO ugly
-            if vs_names is not None:
-                assert len(vs_names) == len(towers)
-                for idx, name in enumerate(vs_names):
-                    if name is None:
-                        vs_names[idx] = tower_names[idx]
-            else:
-                vs_names = tower_names
         else:
-            assert vs_names is None
-            vs_names = [''] * len(towers)
+            assert use_vs is None
+        if use_vs is None:
+            use_vs = [False] * len(towers)
+        assert len(use_vs) == len(towers)
 
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
@@ -92,7 +86,7 @@ class MultiGPUTrainerBase(FeedfreeTrainerBase):
                     tower_names[idx],
                     is_training=True,
                     index=idx,
-                    vs_name=vs_names[idx]):
+                    use_vs=use_vs[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
                 else:
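A hedged sketch of how callers now request per-tower variable scopes through build_on_multi_tower; model and input stand in for a ModelDesc and an InputSource and are not defined here:

    grad_list = MultiGPUTrainerBase.build_on_multi_tower(
        [0, 1],                                  # towers: gpu relative ids
        lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
        var_strategy='replicated',
        # tower 0 keeps the plain scope, tower 1 opens a 'tower1' variable scope
        use_vs=[False] + [True] * 1)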
@@ -192,20 +186,32 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
             new_tower_grads.append((grad, v))
         return new_tower_grads
 
-    def _setup(self):
-        super(SyncMultiGPUTrainerParameterServer, self)._setup()
+    @staticmethod
+    def setup_graph(model, input, ps_device, tower):
+        """
+        Args:
+            model (ModelDesc):
+            input (InputSource):
+            ps_device (str):
+            tower (list[int]):
+
+        Returns:
+            tf.Operation: the training op
+            [Callback]: the callbacks to be added
+        """
+        input.setup(model.get_inputs_desc())
 
-        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
-        if self._ps_device == 'gpu':
+        raw_devices = ['/gpu:{}'.format(k) for k in tower]
+        if ps_device == 'gpu':
             devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         else:
             devices = [tf.train.replica_device_setter(
                 worker_device=d, ps_device='/cpu:0', ps_tasks=1) for d in raw_devices]
 
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower,
-            lambda: MultiGPUTrainerBase._build_graph_get_grads(
-                self.model, self._input_source), devices)
+            tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
+            devices)
         MultiGPUTrainerBase._check_grad_list(grad_list)
 
         # debug tower performance (without update):
@@ -213,11 +219,16 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainerBase):
         # self.train_op = tf.group(*ops)
         # return
 
-        grads = self._average_grads(grad_list)
+        grads = SyncMultiGPUTrainerParameterServer._average_grads(grad_list)
         # grads = grad_list[0]
 
-        self.train_op = self.model.get_optimizer().apply_gradients(
-            grads, name='train_op')
+        train_op = model.get_optimizer().apply_gradients(grads, name='train_op')
+        return train_op, input.get_callbacks()
+
+    def _setup(self):
+        self.train_op, cbs = SyncMultiGPUTrainerParameterServer.setup_graph(
+            self.model, self._input_source, self._ps_device, self.config.tower)
+        self.config.callbacks.extend(cbs)
 
 
 def SyncMultiGPUTrainer(config):
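With setup_graph exposed as a static method, graph construction no longer has to go through _setup. A hedged usage sketch (model and input_source are assumed to be a ModelDesc and an InputSource created elsewhere; the tower list is illustrative):

    train_op, callbacks = SyncMultiGPUTrainerParameterServer.setup_graph(
        model, input_source, ps_device='gpu', tower=[0, 1, 2, 3])
    # train_op applies the averaged gradients; the returned callbacks come from the
    # InputSource and still need to be registered, as _setup() does above.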
@@ -266,31 +277,47 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase):
         # NVar * NGPU * 2
         return new_tower_grads
 
-    def _setup(self):
-        super(SyncMultiGPUTrainerReplicated, self)._setup()
-        raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+    @staticmethod
+    def setup_graph(model, input, tower):
+        """
+        Args:
+            model (ModelDesc):
+            input (InputSource):
+            tower (list[int]):
+
+        Returns:
+            tf.Operation: the training op
+            [Callback]: the callbacks to be added
+        """
+        input.setup(model.get_inputs_desc())
+        raw_devices = ['/gpu:{}'.format(k) for k in tower]
 
         grad_list = MultiGPUTrainerBase.build_on_multi_tower(
-            self.config.tower,
-            lambda: MultiGPUTrainerBase._build_graph_get_grads(
-                self.model, self._input_source),
+            tower,
+            lambda: MultiGPUTrainerBase._build_graph_get_grads(model, input),
             var_strategy='replicated',
             # use no variable scope for the first tower
-            vs_names=[''] + [None] * (self.config.nr_tower - 1))
+            use_vs=[False] + [True] * (len(tower) - 1))
 
-        grads = self._allreduce_grads(grad_list)
+        grads = SyncMultiGPUTrainerReplicated._allreduce_grads(grad_list)
 
         train_ops = []
-        opt = self.model.get_optimizer()
-        for idx in range(self.config.nr_tower):
+        opt = model.get_optimizer()
+        for idx in range(len(tower)):
             with tf.device(raw_devices[idx]):
                 grad_and_vars = [x[idx] for x in grads]
                 train_ops.append(opt.apply_gradients(
                     grad_and_vars, name='apply_grad_{}'.format(idx)))
-        self.train_op = tf.group(*train_ops, name='train_op')
+        train_op = tf.group(*train_ops, name='train_op')
 
-        self.register_callback(RunOp(
+        cb = RunOp(
             SyncMultiGPUTrainerReplicated.get_post_init_ops,
-            run_before=True, run_as_trigger=True, verbose=True))
+            run_before=True, run_as_trigger=True, verbose=True)
+        return train_op, input.get_callbacks() + [cb]
+
+    def _setup(self):
+        self.train_op, cbs = SyncMultiGPUTrainerReplicated.setup_graph(
+            self.model, self._input_source, self.config.tower)
+        self.config.callbacks.extend(cbs)
 
     # Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
     @staticmethod
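The replicated trainer exposes the same kind of entry point; a hedged sketch with placeholder arguments:

    train_op, callbacks = SyncMultiGPUTrainerReplicated.setup_graph(
        model, input_source, tower=[0, 1])
    # besides the InputSource callbacks, the list includes the RunOp that executes
    # get_post_init_ops to keep the per-tower variable copies in sync.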