Commit 08a5cf6f authored by Yuxin Wu's avatar Yuxin Wu

fix regularize collection for replicated mode

parent 2ecdbc00
......@@ -61,14 +61,19 @@ def regularize_cost(regex, func, name='regularize_cost'):
def regularize_cost_from_collection(name='regularize_cost'):
"""
Get the cost from the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
In replicated mode, will only regularize variables within the current tower.
Returns:
a scalar tensor, the regularization loss.
"""
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
if len(regularization_losses) > 0:
# NOTE: this collection doesn't grow with towers.
# It is only added with variables that are newly created.
if ctx.has_own_variables: # be careful of the first tower (name='')
regularization_losses = ctx.filter_vars_by_vs_name(regularization_losses)
print([k.name for k in regularization_losses])
logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
reg_loss = tf.add_n(list(regularization_losses), name=name)
return reg_loss
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment