Commit 3d9dc7d0 authored by Yuxin Wu's avatar Yuxin Wu

(fix #369)

parent a682083f
...@@ -147,7 +147,7 @@ class ModelDesc(ModelDescBase): ...@@ -147,7 +147,7 @@ class ModelDesc(ModelDescBase):
""" """
cost = self._get_cost() cost = self._get_cost()
reg_cost = regularize_cost_from_collection() reg_cost = regularize_cost_from_collection()
if reg_cost: if reg_cost is not None:
return tf.add(cost, reg_cost, name='cost_with_regularizer') return tf.add(cost, reg_cost, name='cost_with_regularizer')
else: else:
return cost return cost
......
...@@ -64,7 +64,7 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -64,7 +64,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
In replicated mode, will only regularize variables within the current tower. In replicated mode, will only regularize variables within the current tower.
Returns: Returns:
a scalar tensor, the regularization loss, or 0 a scalar tensor, the regularization loss, or None
""" """
regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) regularization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context() ctx = get_current_tower_context()
...@@ -77,7 +77,7 @@ def regularize_cost_from_collection(name='regularize_cost'): ...@@ -77,7 +77,7 @@ def regularize_cost_from_collection(name='regularize_cost'):
reg_loss = tf.add_n(list(regularization_losses), name=name) reg_loss = tf.add_n(list(regularization_losses), name=name)
return reg_loss return reg_loss
else: else:
return 0 return None
@layer_register(log_shape=False, use_scope=False) @layer_register(log_shape=False, use_scope=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment