Commit 2afe02ba authored by Yuxin Wu

make regularization work for REPLICATED

parent 731205ac
......@@ -64,13 +64,12 @@ def regularize_cost_from_collection(name='regularize_cost'):
Returns:
a scalar tensor, the regularization loss.
"""
regulization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
if len(regulization_losses) > 0:
# TODO only regularize variables in this tower?
assert not ctx.has_own_variables, "REGULARIZATION_LOSSES collection doesn't work in replicated mode!"
logger.info("Apply REGULARIZATION_LOSSES on the total cost.")
reg_loss = tf.add_n(list(regulization_losses), name=name)
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
if len(regularization_losses) > 0:
# NOTE: this collection doesn't grow with towers.
# It is only added with variables that are newly created.
logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
reg_loss = tf.add_n(list(regularization_losses), name=name)
return reg_loss
else:
return tf.constant(0, dtype=tf.float32, name='empty_' + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment