Commit 3d9dc7d0 authored by Yuxin Wu's avatar Yuxin Wu

(fix #369)

parent a682083f
......@@ -147,7 +147,7 @@ class ModelDesc(ModelDescBase):
"""
cost = self._get_cost()
reg_cost = regularize_cost_from_collection()
if reg_cost:
if reg_cost is not None:
return tf.add(cost, reg_cost, name='cost_with_regularizer')
else:
return cost
......
......@@ -64,7 +64,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
In replicated mode, will only regularize variables within the current tower.
Returns:
a scalar tensor, the regularization loss, or 0
a scalar tensor, the regularization loss, or None
"""
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
......@@ -77,7 +77,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
reg_loss = tf.add_n(list(regularization_losses), name=name)
return reg_loss
else:
return 0
return None
@layer_register(log_shape=False, use_scope=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment