Commit 3f5b9d51 authored by Yuxin Wu

Get tower-owned variables through TowerContext

parent a397cebc
......@@ -43,7 +43,7 @@ def update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay):
update_op2 = moving_averages.assign_moving_average(
moving_var, batch_var, decay, zero_debias=False,
name='var_ema_op')
# Only add model var when we update them
# Only add to the model-variable collection when we update them
add_model_variable(moving_mean)
add_model_variable(moving_var)
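For reference, `assign_moving_average` with `zero_debias=False` applies the standard exponential-moving-average update to the moving statistics. A minimal NumPy sketch of that arithmetic (names here are illustrative, not tensorpack's code):

```python
import numpy as np

def ema_update(moving, batch, decay):
    # Same arithmetic as assign_moving_average(zero_debias=False):
    #   moving <- decay * moving + (1 - decay) * batch
    return decay * moving + (1.0 - decay) * batch

moving_mean = np.array([0.0, 0.0])
batch_mean = np.array([1.0, -1.0])
print(ema_update(moving_mean, batch_mean, decay=0.9))  # [ 0.1 -0.1]
```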
......@@ -143,7 +143,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
# maintain EMA only on one GPU is OK.
# Maintaining the EMA on only one GPU is OK, even in replicated mode,
# because training doesn't use the EMA.
if ctx.is_main_training_tower:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
......@@ -231,8 +232,9 @@ def BatchRenorm(x, rmax, dmax, decay=0.9, epsilon=1e-5,
xn = tf.nn.batch_normalization(
x, moving_mean, moving_var, beta, gamma, epsilon)
# training also needs EMA, so ideally we should maintain it on every tower
if ctx.is_main_training_tower or ctx.has_own_variables:
# Training also needs the EMA, so maintain it as long as this tower has
# its own EMA variables.
if ctx.has_own_variables:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
ret = tf.identity(xn, name='output')
......
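The two layers now gate the EMA update differently: BatchNorm keeps it only on the main training tower, while BatchRenorm keeps it on every tower that owns EMA variables. A toy stand-in for TowerContext (illustrative only, not the real class) shows which towers would run `update_bn_ema` under each policy in replicated mode:

```python
class FakeTowerContext(object):
    """Toy stand-in mirroring the two properties used above."""
    def __init__(self, index, vs_name):
        self._index = index
        self._vs_name = vs_name

    @property
    def is_main_training_tower(self):
        return self._index == 0

    @property
    def has_own_variables(self):
        # New definition from this commit: the main tower always counts,
        # even when its variable scope name is empty.
        return self.is_main_training_tower or len(self._vs_name) > 0

# Replicated mode: tower 0 uses the root scope, tower 1 uses 'tower1'.
towers = [FakeTowerContext(0, ''), FakeTowerContext(1, 'tower1')]
print([t.is_main_training_tower for t in towers])  # BatchNorm EMA:   [True, False]
print([t.has_own_variables for t in towers])       # BatchRenorm EMA: [True, True]
```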
......@@ -40,15 +40,15 @@ def regularize_cost(regex, func, name='regularize_cost'):
cost = cost + regularize_cost("fc.*/W", l2_regularizer(1e-5))
"""
ctx = get_current_tower_context()
G = tf.get_default_graph()
params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
params = tf.trainable_variables()
# If vars are shared, use all of them
# If vars are replicated, only regularize those in the current tower
params = ctx.filter_vars_by_vs_name(params)
costs = []
for p in params:
para_name = p.name
# in replicated mode, only regularize variables inside this tower
if ctx.has_own_variables and ctx.vs_name and (not para_name.startswith(ctx.vs_name)):
continue
if re.search(regex, para_name):
costs.append(func(p))
_log_regularizer(para_name)
......
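Put together, `regularize_cost` now first restricts the trainable variables to the current tower's variable scope and then applies the user-supplied regex. A self-contained sketch of that selection on plain name strings (not the function's actual code):

```python
import re

def select_regularized(var_names, vs_name, regex):
    # 1. Keep only names under the tower's variable scope, if it has one.
    if vs_name:
        prefix = vs_name + '/'
        var_names = [n for n in var_names if n.startswith(prefix)]
    # 2. Keep the names matching the user-supplied regex.
    return [n for n in var_names if re.search(regex, n)]

names = ['tower1/fc0/W', 'tower1/fc0/b', 'tower2/fc0/W']
print(select_regularized(names, 'tower1', r'fc.*/W'))  # ['tower1/fc0/W']
```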
......@@ -48,7 +48,7 @@ class TowerContext(object):
@property
def has_own_variables(self):
return len(self._vs_name) > 0
return self.is_main_training_tower or len(self._vs_name) > 0
@property
def name(self):
......@@ -60,6 +60,23 @@ class TowerContext(object):
def vs_name(self):
return self._vs_name
def filter_vars_by_vs_name(self, varlist):
"""
Filter the list and only keep variables under the current variable scope.
If this tower doesn't have its own variable scope, return the list as-is.
Args:
varlist (list[tf.Variable] or list[tf.Tensor]): the list to filter.
"""
if not self.has_own_variables:
return varlist
if len(self._vs_name) == 0:
# This is the main training tower with an empty scope name. Assume no other
# towers have been built yet, so varlist only contains vars from the first tower.
return varlist
prefix = self._vs_name + '/'
return [v for v in varlist if v.op.name.startswith(prefix)]
@property
def index(self):
return self._index
......
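The prefix test operates on `v.op.name` (which, unlike `v.name`, carries no `:0` suffix). A toy demonstration with stand-in variable objects, where only the `.op.name` attribute matters:

```python
from collections import namedtuple

Op = namedtuple('Op', ['name'])
Var = namedtuple('Var', ['op'])

def filter_by_vs_name(varlist, vs_name):
    # Mirrors filter_vars_by_vs_name: no private scope means keep everything.
    if not vs_name:
        return varlist
    prefix = vs_name + '/'
    return [v for v in varlist if v.op.name.startswith(prefix)]

varlist = [Var(Op('tower1/conv0/W')), Var(Op('conv0/W'))]
print([v.op.name for v in filter_by_vs_name(varlist, 'tower1')])  # ['tower1/conv0/W']
```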
......@@ -49,11 +49,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
cost = self.model.get_cost() # assume single cost
# produce gradients
varlist = tf.trainable_variables()
if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower
# TODO use ctx.vars?
varlist = [v for v in varlist if v.op.name.startswith(ctx.vs_name + '/')]
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
grads = tf.gradients(
cost,
varlist,
......
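A minimal TF 1.x sketch (not the trainer's exact code) of the effect: restricting the variable list before `tf.gradients` yields gradients only for the tower-owned variables.

```python
import tensorflow as tf

with tf.variable_scope('tower1'):
    w_own = tf.get_variable('W', shape=[3], initializer=tf.ones_initializer())
w_shared = tf.get_variable('W_shared', shape=[3], initializer=tf.ones_initializer())

cost = tf.reduce_sum(2.0 * w_own) + tf.reduce_sum(3.0 * w_shared)

# The same restriction filter_vars_by_vs_name performs for vs_name == 'tower1'.
varlist = [v for v in tf.trainable_variables()
           if v.op.name.startswith('tower1/')]
grads = tf.gradients(cost, varlist)  # gradients only w.r.t. tower1/W
```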
......@@ -286,7 +286,8 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
@staticmethod
def get_post_init_ops():
# Copy initialized values for variables on GPU 0 to other GPUs.
all_vars = tf.trainable_variables() # TODO model_variables?
all_vars = tf.trainable_variables()
all_vars.extend(tf.model_variables())
var_by_name = dict([(v.name, v) for v in all_vars])
post_init_ops = []
for v in all_vars:
......
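The loop body of `get_post_init_ops` is truncated above. As a hedged sketch, assuming replicated towers prefix their variables with `towerN/`, the copy ops can be built roughly like this (illustrative only, not the library's exact continuation):

```python
import re
import tensorflow as tf

def make_post_init_ops(all_vars):
    # Map every variable (trainable + model variables) by its full name.
    var_by_name = dict([(v.name, v) for v in all_vars])
    post_init_ops = []
    for v in all_vars:
        match = re.match(r'tower\d+/(.*)', v.name)
        if not match:
            continue  # a variable of the main tower (GPU 0): nothing to copy
        copy_from = var_by_name.get(match.group(1))
        if copy_from is not None:
            # Overwrite this tower's value with the main tower's value.
            post_init_ops.append(v.assign(copy_from))
    return tf.group(*post_init_ops, name='sync_variables_from_main_tower')
```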