Commit 365c56d2 authored by Yuxin Wu's avatar Yuxin Wu

fix ns in replicated mode

parent 8380cfa7
......@@ -40,9 +40,12 @@ def override_to_local_variable(enable=True):
_replace_global_by_local(kwargs)
return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope()
# TODO TF1.5 has https://github.com/tensorflow/tensorflow/pull/14390
with tf.variable_scope(
tf.get_variable_scope(),
custom_getter=custom_getter):
with tf.name_scope(orig_vs.original_name_scope):
yield
else:
yield
......
......@@ -107,7 +107,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if len(losses) > 0:
logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(losses)))
reg_loss = tf.add_n(losses)
reg_loss = tf.add_n(losses, name=name)
return reg_loss
else:
return None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment