Commit a77cc508 authored by Yuxin Wu

remove var_strategy from tower. Use vs_name in a cleaner way.

parent 4d2a7b4c
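
For context, the visible effect on call sites is that `var_strategy` disappears from `TowerContext` and a tower owns its variables exactly when it is given a non-empty `vs_name`. A minimal before/after sketch (the import path and the example variable are assumptions, not part of this commit):

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext  # import path assumed for this version

# Before this commit, replication was requested explicitly:
#     TowerContext('tower1', is_training=True, index=1, var_strategy='replicated')
# After this commit, a non-empty vs_name alone has that effect:
with TowerContext('tower1', is_training=True, index=1, vs_name='tower1'):
    # Variables created here should land under the 'tower1' variable scope,
    # so has_own_variables is True for this tower.
    w = tf.get_variable('W', shape=[10, 10])
```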
@@ -14,20 +14,15 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """
-    def __init__(self, tower_name,
-                 is_training=None,
-                 index=0,
-                 var_strategy='shared',
-                 vs_name=None):
+    def __init__(self, tower_name, is_training=None,
+                 index=0, vs_name=''):
         """
         Args:
             tower_name (str): The name scope of the tower. Currently used
                 values are like: 'tower0', 'towerp0', or ''
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower
-            var_strategy (str): either 'shared' or 'replicated'.
-            vs_name (str): the variable scope name to open. Only valid in
-                'replicated' mode. Defaults to be tower_name.
+            vs_name (str): Open a variable scope with this name, if given.
         """
         self._name = tower_name
@@ -37,17 +32,7 @@ class TowerContext(object):
         self._index = int(index)
-        assert var_strategy in ['replicated', 'shared'], var_strategy
-        self._var_strategy = var_strategy
-        if self._var_strategy == 'replicated':
-            assert self._name
-            if vs_name is None:
-                self._vs_name = self._name
-            else:
-                self._vs_name = vs_name
-        else:
-            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
-            self._vs_name = ''
     @property
     def is_main_training_tower(self):
@@ -63,7 +48,7 @@ class TowerContext(object):
     @property
     def has_own_variables(self):
-        return self._var_strategy == 'replicated'
+        return len(self._vs_name) > 0
     @property
     def name(self):
@@ -70,22 +70,29 @@ class MultiGPUTrainerBase(Trainer):
         if devices is not None:
             assert len(devices) == len(towers)
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':    # TODO ugly
             logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+            # fix all Nones. TODO ugly
+            if vs_names is not None:
+                assert len(vs_names) == len(towers)
+                for idx, name in enumerate(vs_names):
+                    if name is None:
+                        vs_names[idx] = tower_names[idx]
+            else:
+                vs_names = tower_names
         else:
             assert vs_names is None
-        if vs_names is None:
-            vs_names = [None] * len(towers)
+            vs_names = [''] * len(towers)
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with tf.device(device), TowerContext(
-                    'tower{}'.format(idx),
+                    tower_names[idx],
                     is_training=True,
                     index=idx,
-                    var_strategy=var_strategy,
                     vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
@@ -279,17 +286,21 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
     @staticmethod
     def get_post_init_ops():
         # Copy initialized values for variables on GPU 0 to other GPUs.
-        global_vars = tf.global_variables()
-        var_by_name = dict([(v.name, v) for v in global_vars])
+        all_vars = tf.trainable_variables()    # TODO model_variables?
+        var_by_name = dict([(v.name, v) for v in all_vars])
         post_init_ops = []
-        for v in global_vars:
+        for v in all_vars:
             split_name = v.name.split('/')
             if not v.name.startswith('tower'):
                 continue
-            # the master name doesn't have the towerx/ prefix
+            if v.name.startswith('tower0'):
+                continue    # TODO some vars (EMA) may still startswith tower0
+            # in this trainer, the master name doesn't have the towerx/ prefix
             split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
         logger.info(
             "'sync_variables_from_tower0' includes {} operations.".format(len(post_init_ops)))
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')
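
The new get_post_init_ops walks only the trainable variables, treats names without a towerN/ prefix as the master copies, skips tower0, and assigns the master values into every other tower. A self-contained sketch of just the name mapping on plain strings (the helper and the example names are illustrative, not from the commit):

```python
def tower0_sync_pairs(var_names):
    """(target, source) pairs describing which tower variable copies which master value."""
    names = set(var_names)
    pairs = []
    for name in var_names:
        if not name.startswith('tower') or name.startswith('tower0'):
            continue  # variables without a towerN/ prefix (and tower0's) are never targets
        master = '/'.join(name.split('/')[1:])  # strip the 'towerN/' prefix
        if master in names:
            pairs.append((name, master))
    return pairs

# Example with made-up variable names:
print(tower0_sync_pairs(['conv0/W:0', 'tower1/conv0/W:0', 'tower2/conv0/W:0']))
# -> [('tower1/conv0/W:0', 'conv0/W:0'), ('tower2/conv0/W:0', 'conv0/W:0')]
```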