Commit 2afe02ba authored by Yuxin Wu's avatar Yuxin Wu

make regularization work for REPLICATED

parent 731205ac
...@@ -64,13 +64,12 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -64,13 +64,12 @@ def regularize_cost_from_collection(name='regularize_cost'):
Returns: Returns:
a scalar tensor, the regularization loss. a scalar tensor, the regularization loss.
""" """
regulization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context() if len(regularization_losses) > 0:
if len(regulization_losses) > 0: # NOTE: this collection doesn't grow with towers.
# TODO only regularize variables in this tower? # It is only added with variables that are newly created.
assert not ctx.has_own_variables, "REGULARIZATION_LOSSES collection doesn't work in replicated mode!" logger.info("Add REGULARIZATION_LOSSES of {} tensors on the total cost.".format(len(regularization_losses)))
logger.info("Apply REGULARIZATION_LOSSES on the total cost.") reg_loss = tf.add_n(list(regularization_losses), name=name)
reg_loss = tf.add_n(list(regulization_losses), name=name)
return reg_loss return reg_loss
else: else:
return tf.constant(0, dtype=tf.float32, name='empty_' + name) return tf.constant(0, dtype=tf.float32, name='empty_' + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment