Commit c5e05d7a authored by Yuxin Wu

fix colocation problems

parent 5d529d03
...@@ -47,12 +47,14 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -47,12 +47,14 @@ def regularize_cost(regex, func, name='regularize_cost'):
# If vars are replicated, only regularize those in the current tower # If vars are replicated, only regularize those in the current tower
params = ctx.filter_vars_by_vs_name(params) params = ctx.filter_vars_by_vs_name(params)
G = tf.get_default_graph()
with tf.name_scope('regularize_cost'): with tf.name_scope('regularize_cost'):
costs = [] costs = []
for p in params: for p in params:
para_name = p.name para_name = p.name
if re.search(regex, para_name): if re.search(regex, para_name):
costs.append(func(p)) with G.colocate_with(p):
costs.append(func(p))
_log_regularizer(para_name) _log_regularizer(para_name)
if not costs: if not costs:
return tf.constant(0, dtype=tf.float32, name='empty_' + name) return tf.constant(0, dtype=tf.float32, name='empty_' + name)
......
...@@ -128,8 +128,7 @@ class MapGradient(GradientProcessor): ...@@ -128,8 +128,7 @@ class MapGradient(GradientProcessor):
for grad, var in grads: for grad, var in grads:
if re.match(self.regex, var.op.name): if re.match(self.regex, var.op.name):
matched = True matched = True
with tf.device(grad.device): grad = self.func(grad, var)
grad = self.func(grad, var)
if grad is not None: if grad is not None:
ret.append((grad, var)) ret.append((grad, var))
else: else:
......
...@@ -241,8 +241,9 @@ def add_moving_summary(*args, **kwargs): ...@@ -241,8 +241,9 @@ def add_moving_summary(*args, **kwargs):
ema_op = moving_averages.assign_moving_average( ema_op = moving_averages.assign_moving_average(
ema_var, c, decay, ema_var, c, decay,
zero_debias=True, name=name + '_EMA_apply') zero_debias=True, name=name + '_EMA_apply')
tf.summary.scalar(name + '-summary', ema_op) # write the EMA value as a summary
ema_ops.append(ema_op) ema_ops.append(ema_op)
# cannot add it into colocate group -- will force everything to cpus
tf.summary.scalar(name + '-summary', ema_op) # write the EMA value as a summary
if coll is not None: if coll is not None:
for op in ema_ops: for op in ema_ops:
# TODO a new collection to summary every step? # TODO a new collection to summary every step?
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment