Commit a77cc508 authored by Yuxin Wu

remove var_strategy from tower. Use vs_name in a cleaner way.

parent 4d2a7b4c
@@ -14,20 +14,15 @@ _CurrentTowerContext = None
 class TowerContext(object):
     """ A context where the current model is being built in. """
 
-    def __init__(self, tower_name,
-                 is_training=None,
-                 index=0,
-                 var_strategy='shared',
-                 vs_name=None):
+    def __init__(self, tower_name, is_training=None,
+                 index=0, vs_name=''):
         """
         Args:
             tower_name (str): The name scope of the tower. Currently used
                 values are like: 'tower0', 'towerp0', or ''
             is_training (bool): if None, automatically determine from tower_name.
             index (int): index of this tower
-            var_strategy (str): either 'shared' or 'replicated'.
-            vs_name (str): the variable scope name to open. Only valid in
-                'replicated' mode. Defaults to be tower_name.
+            vs_name (str): Open a variable scope with this name, if given.
         """
         self._name = tower_name
@@ -37,17 +32,7 @@ class TowerContext(object):
         self._index = int(index)
 
-        assert var_strategy in ['replicated', 'shared'], var_strategy
-        self._var_strategy = var_strategy
-        if self._var_strategy == 'replicated':
-            assert self._name
-            if vs_name is None:
-                self._vs_name = self._name
-            else:
-                self._vs_name = vs_name
-        else:
-            assert vs_name is None, "vs_name is only valid in 'replicated' mode!"
-            self._vs_name = ''
+        self._vs_name = vs_name
 
     @property
     def is_main_training_tower(self):
@@ -63,7 +48,7 @@ class TowerContext(object):
     @property
     def has_own_variables(self):
-        return self._var_strategy == 'replicated'
+        return len(self._vs_name) > 0
 
     @property
     def name(self):
...
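For illustration, a minimal standalone sketch (plain Python, not tensorpack's actual class, and without any TensorFlow variable-scope handling) of what the simplified ownership check amounts to: a tower owns its variables exactly when it was constructed with a non-empty vs_name.

# Sketch only: mirrors the simplified logic above, not the real TowerContext.
class TowerContextSketch(object):
    def __init__(self, tower_name, is_training=None, index=0, vs_name=''):
        self._name = tower_name
        self._index = int(index)
        self._vs_name = vs_name   # '' means "reuse the shared/master variables"

    @property
    def has_own_variables(self):
        # Replicated towers are constructed with vs_name='towerN', so this
        # returns True for them; shared towers keep the default '' and return False.
        return len(self._vs_name) > 0

# Usage:
assert TowerContextSketch('tower1', vs_name='tower1').has_own_variables
assert not TowerContextSketch('tower1').has_own_variables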
@@ -70,22 +70,29 @@ class MultiGPUTrainerBase(Trainer):
         if devices is not None:
             assert len(devices) == len(towers)
 
+        tower_names = ['tower{}'.format(idx) for idx in range(len(towers))]
         keys_to_freeze = TOWER_FREEZE_KEYS[:]
         if var_strategy == 'replicated':    # TODO ugly
             logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
             keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
+            # fix all Nones. TODO ugly
+            if vs_names is not None:
+                assert len(vs_names) == len(towers)
+                for idx, name in enumerate(vs_names):
+                    if name is None:
+                        vs_names[idx] = tower_names[idx]
+            else:
+                vs_names = tower_names
         else:
             assert vs_names is None
-        if vs_names is None:
-            vs_names = [None] * len(towers)
+            vs_names = [''] * len(towers)
 
         for idx, t in enumerate(towers):
             device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
             with tf.device(device), TowerContext(
-                    'tower{}'.format(idx),
+                    tower_names[idx],
                     is_training=True,
                     index=idx,
-                    var_strategy=var_strategy,
                     vs_name=vs_names[idx]):
                 if idx == t:
                     logger.info("Building graph for training tower {}...".format(idx))
@@ -279,17 +286,21 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
     @staticmethod
     def get_post_init_ops():
         # Copy initialized values for variables on GPU 0 to other GPUs.
-        global_vars = tf.global_variables()
-        var_by_name = dict([(v.name, v) for v in global_vars])
+        all_vars = tf.trainable_variables()  # TODO model_variables?
+        var_by_name = dict([(v.name, v) for v in all_vars])
         post_init_ops = []
-        for v in global_vars:
+        for v in all_vars:
             split_name = v.name.split('/')
             if not v.name.startswith('tower'):
                 continue
-            # the master name doesn't have the towerx/ prefix
+            if v.name.startswith('tower0'):
+                continue  # TODO some vars (EMA) may still startswith tower0
+            # in this trainer, the master name doesn't have the towerx/ prefix
             split_name = split_name[1:]
             copy_from = var_by_name['/'.join(split_name)]
             post_init_ops.append(v.assign(copy_from.read_value()))
+        logger.info(
+            "'sync_variables_from_tower0' includes {} operations.".format(len(post_init_ops)))
         return tf.group(*post_init_ops, name='sync_variables_from_tower0')
...
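The copy logic in get_post_init_ops hinges on a pure name transformation: strip the leading 'towerN/' from a replicated variable's name to find the master variable it should be initialized from. A TensorFlow-free sketch of just that mapping (the variable names below are made up for illustration):

# Sketch of the name mapping only; the real code then builds v.assign(...) ops.
def master_name(var_name):
    if not var_name.startswith('tower'):
        return None   # not a per-tower variable: nothing to copy
    if var_name.startswith('tower0'):
        return None   # already the master copy (or an EMA var, see TODO above)
    return '/'.join(var_name.split('/')[1:])

for name in ['conv0/W:0', 'tower0/conv0/W:0', 'tower1/conv0/W:0', 'tower2/conv0/b:0']:
    print('{:20s} -> {}'.format(name, master_name(name)))
# tower1/conv0/W:0 -> conv0/W:0, tower2/conv0/b:0 -> conv0/b:0, others are skipped (None)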