Commit bd5e0591 authored by Yuxin Wu's avatar Yuxin Wu

move slim regularize cost to regularize.py

parent 4ebdd71d
...@@ -11,9 +11,9 @@ import six ...@@ -11,9 +11,9 @@ import six
from ..utils import logger from ..utils import logger
from ..utils.naming import INPUTS_KEY from ..utils.naming import INPUTS_KEY
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..tfutils.model_utils import apply_slim_collections from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'InputVar', 'ModelDesc', 'ModelFromMetaGraph'] __all__ = ['InputDesc', 'InputVar', 'ModelDesc']
class InputDesc(object): class InputDesc(object):
...@@ -119,10 +119,10 @@ class ModelDesc(object): ...@@ -119,10 +119,10 @@ class ModelDesc(object):
This function also applies the collection This function also applies the collection
``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically. ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
Because slim users would expect the regularizer being automatically applied once used in slim layers.
""" """
cost = self._get_cost() cost = self._get_cost()
return apply_slim_collections(cost) return tf.add(cost, regularize_cost_from_collection(),
name='cost_with_regularizer')
def _get_cost(self, *args): def _get_cost(self, *args):
return self.cost return self.cost
...@@ -158,7 +158,7 @@ class ModelFromMetaGraph(ModelDesc): ...@@ -158,7 +158,7 @@ class ModelFromMetaGraph(ModelDesc):
Only useful for inference. Only useful for inference.
""" """
# TODO this class may not be functional anymore. # TODO this class may not be functional anymore. don't use
def __init__(self, filename): def __init__(self, filename):
""" """
......
...@@ -53,10 +53,29 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -53,10 +53,29 @@ def regularize_cost(regex, func, name='regularize_cost'):
costs.append(func(p)) costs.append(func(p))
_log_regularizer(para_name) _log_regularizer(para_name)
if not costs: if not costs:
return tf.constant(0, dtype=tf.float32, name='empty_regularize_cost') return tf.constant(0, dtype=tf.float32, name='empty_' + name)
return tf.add_n(costs, name=name) return tf.add_n(costs, name=name)
def regularize_cost_from_collection(name='regularize_cost'):
"""
Get the cost from the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
Returns:
a scalar tensor, the regularization loss.
"""
regulization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
if len(regulization_losses) > 0:
# TODO only regularize variables in this tower?
assert not ctx.has_own_variables, "REGULARIZATION_LOSSES collection doesn't work in replicated mode!"
logger.info("Apply REGULARIZATION_LOSSES on the total cost.")
reg_loss = tf.add_n(list(regulization_losses), name=name)
return reg_loss
else:
return tf.constant(0, dtype=tf.float32, name='empty_' + name)
@layer_register(log_shape=False, use_scope=False) @layer_register(log_shape=False, use_scope=False)
def Dropout(x, keep_prob=0.5, is_training=None, noise_shape=None): def Dropout(x, keep_prob=0.5, is_training=None, noise_shape=None):
""" """
......
...@@ -6,11 +6,9 @@ import tensorflow as tf ...@@ -6,11 +6,9 @@ import tensorflow as tf
from termcolor import colored from termcolor import colored
from tabulate import tabulate from tabulate import tabulate
from ..tfutils.tower import get_current_tower_context
from ..utils import logger from ..utils import logger
from .summary import add_moving_summary
__all__ = ['describe_model', 'get_shape_str', 'apply_slim_collections'] __all__ = ['describe_model', 'get_shape_str']
def describe_model(): def describe_model():
...@@ -65,24 +63,3 @@ def get_shape_str(tensors): ...@@ -65,24 +63,3 @@ def get_shape_str(tensors):
assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors)) assert isinstance(tensors, (tf.Tensor, tf.Variable)), "Not a tensor: {}".format(type(tensors))
shape_str = str(tensors.get_shape().as_list()) shape_str = str(tensors.get_shape().as_list())
return shape_str return shape_str
def apply_slim_collections(cost):
"""
Add the cost with the regularizers in ``tf.GraphKeys.REGULARIZATION_LOSSES``.
Args:
cost: a scalar tensor
Return:
a scalar tensor, the cost after applying the collections.
"""
regulization_losses = set(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
ctx = get_current_tower_context()
if len(regulization_losses) > 0:
assert not ctx.has_own_variables, "REGULARIZATION_LOSSES collection doesn't work in replicated mode!"
logger.info("Applying REGULARIZATION_LOSSES on cost.")
reg_loss = tf.add_n(list(regulization_losses), name="regularize_loss")
cost = tf.add(reg_loss, cost, name='total_cost')
add_moving_summary(reg_loss, cost)
return cost
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment