Commit 3f5b9d51 authored by Yuxin Wu

Get tower-owned variables through TowerContext

parent a397cebc
@@ -43,7 +43,7 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
     update_op2 = moving_averages.assign_moving_average(
         moving_var, batch_var, decay, zero_debias=False,
         name='var_ema_op')
-    # Only add model var when we update them
+    # Only add to model var when we update them
     add_model_variable(moving_mean)
     add_model_variable(moving_var)
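
For context, the EMA bookkeeping that this hunk adjusts amounts to two assign_moving_average ops plus registering the statistics as model variables. A minimal, standalone sketch of that pattern (illustrative function name, not tensorpack's exact implementation):

import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages

def ema_update_sketch(batch_mean, batch_var, moving_mean, moving_var, decay=0.9):
    # Each op performs: moving_x := decay * moving_x + (1 - decay) * batch_x
    op1 = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    op2 = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    # Register the statistics as model variables only in the tower that
    # actually updates them, so they are saved/restored exactly once.
    add_model_variable(moving_mean)
    add_model_variable(moving_var)
    return tf.group(op1, op2, name='ema_update')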
@@ -143,7 +143,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
         xn = tf.nn.batch_normalization(
             x, moving_mean, moving_var, beta, gamma, epsilon)
-    # maintain EMA only on one GPU is OK.
+    # maintain EMA only on one GPU is OK, even in replicated mode.
+    # because training time doesn't use EMA
     if ctx.is_main_training_tower:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
@@ -231,8 +232,9 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
         xn = tf.nn.batch_normalization(
             x, moving_mean, moving_var, beta, gamma, epsilon)
-    # training also needs EMA, so ideally we should maintain it on every tower
-    if ctx.is_main_training_tower or ctx.has_own_variables:
+    # training also needs EMA, so we should maintain it as long as there are
+    # corresponding EMA variables.
+    if ctx.has_own_variables:
         ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
     else:
         ret = tf.identity(xn, name='output')
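
Taken together, the two hunks above leave the layers differing only in their gate: BatchNorm updates the EMA on the main training tower only, because training never reads the EMA, while BatchRenorm reads the EMA during training and therefore updates it in every tower that owns its own EMA variables. A hedged sketch of that gating, assuming the context helper lives at tensorpack.tfutils.tower as in this version of the codebase, and reusing ema_update_sketch from the earlier sketch:

import tensorflow as tf
from tensorpack.tfutils.tower import get_current_tower_context

def maybe_update_ema(xn, batch_mean, batch_var, moving_mean, moving_var,
                     decay, training_reads_ema):
    ctx = get_current_tower_context()
    # BatchNorm:   training never reads the EMA -> main training tower only.
    # BatchRenorm: training reads the EMA       -> every tower that owns its
    #              own EMA variables.
    should_update = (ctx.has_own_variables if training_reads_ema
                     else ctx.is_main_training_tower)
    if should_update:
        # Return the output only after the EMA ops have run, which is roughly
        # what update_bn_ema arranges in the real code.
        update_op = ema_update_sketch(batch_mean, batch_var,
                                      moving_mean, moving_var, decay)
        with tf.control_dependencies([update_op]):
            return tf.identity(xn, name='output')
    return tf.identity(xn, name='output')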
......
@@ -40,15 +40,15 @@ def regularize_cost(regex, func, name='regularize_cost'):
         cost = cost + regularize_cost("fc.*/W", l2_regularizer(1e-5))
     """
     ctx = get_current_tower_context()
-    G = tf.get_default_graph()
-    params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    params = tf.trainable_variables()
+    # If vars are shared, use all of them
+    # If vars are replicated, only regularize those in the current tower
+    params = ctx.filter_vars_by_vs_name(params)
     costs = []
     for p in params:
         para_name = p.name
-        # in replicated mode, only regularize variables inside this tower
-        if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
-            continue
         if re.search(regex, para_name):
             costs.append(func(p))
             _log_regularizer(para_name)
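
With the tower filtering factored out, the remainder of regularize_cost is a regex match over variable names followed by a sum. A rough sketch of that tail end (hypothetical helper; the real function's accumulation and logging may differ):

import re
import tensorflow as tf

def regularize_cost_sketch(regex, func, params, name='regularize_cost'):
    # `params` is assumed to be the tower-filtered trainable variables,
    # e.g. ctx.filter_vars_by_vs_name(tf.trainable_variables()).
    costs = [func(p) for p in params if re.search(regex, p.name)]
    if not costs:
        return tf.constant(0.0, name=name)
    return tf.add_n(costs, name=name)

It would be called along the lines of regularize_cost_sketch("fc.*/W", tf.contrib.layers.l2_regularizer(1e-5), params), mirroring the docstring example above.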
......
@@ -48,7 +48,7 @@ class TowerContext(object):
     @property
     def has_own_variables(self):
-        return len(self._vs_name) > 0
+        return self.is_main_training_tower or len(self._vs_name) > 0

     @property
     def name(self):
@@ -60,6 +60,23 @@ class TowerContext(object):
     def vs_name(self):
         return self._vs_name

+    def filter_vars_by_vs_name(self, varlist):
+        """
+        Filter the list and only keep those under the current variable scope.
+        If this tower doesn't contain its own variable scope, return the list as-is.
+
+        Args:
+            varlist (list[tf.Variable] or list[tf.Tensor]):
+        """
+        if not self.has_own_variables:
+            return varlist
+        if len(self._vs_name) == 0:
+            # main_training_tower with no name. Assume no other towers have
+            # been built yet, so varlist contains vars only in the first tower.
+            return varlist
+        prefix = self._vs_name + '/'
+        return [v for v in varlist if v.op.name.startswith(prefix)]
+
     @property
     def index(self):
         return self._index
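
To make the new method's behavior concrete, here is a tiny runnable illustration with made-up variable names (the tower scope names are hypothetical):

from collections import namedtuple

# Stand-ins for tf.Variable objects: only the .op.name attribute matters here.
Op = namedtuple('Op', 'name')
Var = namedtuple('Var', 'op')

varlist = [Var(Op(n)) for n in [
    'tower0/conv0/W', 'tower0/fc0/W', 'tower1/conv0/W', 'tower1/fc0/W']]

# What filter_vars_by_vs_name keeps for a tower whose vs_name is 'tower1':
prefix = 'tower1' + '/'
print([v.op.name for v in varlist if v.op.name.startswith(prefix)])
# -> ['tower1/conv0/W', 'tower1/fc0/W']

Appending '/' to the scope name before matching keeps a scope such as 'tower1' from accidentally claiming variables under a scope like 'tower10'.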
......
@@ -49,11 +49,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         cost = self.model.get_cost()  # assume single cost
         # produce gradients
-        varlist = tf.trainable_variables()
-        if ctx is not None and ctx.has_own_variables and ctx.vs_name:
-            # only optimize w.r.t vars in this tower
-            # TODO use ctx.vars?
-            varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
+        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
         grads = tf.gradients(
             cost,
             varlist,
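
With varlist narrowed to the tower's own copies, the gradient step is an ordinary tf.gradients call. A minimal sketch of the shape of it (the packaging into pairs and any extra tf.gradients options are illustrative, not the trainer's exact code):

import tensorflow as tf

def per_tower_grads_sketch(ctx, cost):
    # In replicated mode the filter keeps only this tower's own copies;
    # in shared mode it is a no-op and all trainable variables are used.
    varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
    grads = tf.gradients(cost, varlist)
    return list(zip(grads, varlist))  # (grad, var) pairs for apply_gradients()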
......
@@ -286,7 +286,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrainer):
     @staticmethod
     def get_post_init_ops():
         # Copy initialized values for variables on GPU 0 to other GPUs.
-        all_vars = tf.trainable_variables()  # TODO model_variables?
+        all_vars = tf.trainable_variables()
+        all_vars.extend(tf.model_variables())
         var_by_name = dict([(v.name, v) for v in all_vars])
         post_init_ops = []
         for v in all_vars:
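
Extending all_vars with tf.model_variables() matters because the BatchNorm moving averages are registered as model variables rather than trainable ones, so without this they would be skipped by the post-init copy. A sketch of what such a broadcast can look like, under the assumption that non-main replicas live under scopes like 'tower1/' while the main copy has no such prefix (the trainer's actual scope naming and matching may differ):

import tensorflow as tf

def broadcast_from_main_tower_sketch():
    # Variables to synchronize: trainable ones plus model variables such as
    # the BatchNorm moving averages.
    all_vars = tf.trainable_variables() + tf.model_variables()
    var_by_name = {v.name: v for v in all_vars}
    post_init_ops = []
    for v in all_vars:
        if not v.name.startswith('tower'):
            continue  # assumed: the main copy has no tower prefix
        # e.g. 'tower1/conv0/W:0' -> 'conv0/W:0' to find the main tower's copy
        main_name = '/'.join(v.name.split('/')[1:])
        copy_from = var_by_name[main_name]
        post_init_ops.append(v.assign(copy_from.read_value()))
    return tf.group(*post_init_ops, name='sync_variables_sketch')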
......