Commit 365c56d2 authored by Yuxin Wu's avatar Yuxin Wu

fix ns in replicated mode

parent 8380cfa7
...@@ -40,10 +40,13 @@ def override_to_local_variable(enable=True): ...@@ -40,10 +40,13 @@ def override_to_local_variable(enable=True):
_replace_global_by_local(kwargs) _replace_global_by_local(kwargs)
return getter(name, *args, **kwargs) return getter(name, *args, **kwargs)
orig_vs = tf.get_variable_scope()
# TODO TF1.5 has https://github.com/tensorflow/tensorflow/pull/14390
with tf.variable_scope( with tf.variable_scope(
tf.get_variable_scope(), tf.get_variable_scope(),
custom_getter=custom_getter): custom_getter=custom_getter):
yield with tf.name_scope(orig_vs.original_name_scope):
yield
else: else:
yield yield
......
...@@ -107,7 +107,7 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -107,7 +107,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES) losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
if len(losses) > 0: if len(losses) > 0:
logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(losses))) logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(losses)))
reg_loss = tf.add_n(losses) reg_loss = tf.add_n(losses, name=name)
return reg_loss return reg_loss
else: else:
return None return None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment