Commit 2388f4b8 authored by Yuxin Wu

Simplify messages of regularizers.

parent 02020381
...@@ -48,6 +48,8 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -48,6 +48,8 @@ def regularize_cost(regex, func, name='regularize_cost'):
params = ctx.filter_vars_by_vs_name(params) params = ctx.filter_vars_by_vs_name(params)
G = tf.get_default_graph() G = tf.get_default_graph()
to_regularize = []
with tf.name_scope('regularize_cost'): with tf.name_scope('regularize_cost'):
costs = [] costs = []
for p in params: for p in params:
...@@ -55,9 +57,23 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -55,9 +57,23 @@ def regularize_cost(regex, func, name='regularize_cost'):
if re.search(regex, para_name): if re.search(regex, para_name):
with G.colocate_with(p): with G.colocate_with(p):
costs.append(func(p)) costs.append(func(p))
_log_regularizer(para_name) to_regularize.append(para_name)
if not costs: if not costs:
return tf.constant(0, dtype=tf.float32, name='empty_' + name) return tf.constant(0, dtype=tf.float32, name='empty_' + name)
# remove tower prefix from names, and print
if len(ctx.vs_name):
prefix = ctx.vs_name + '/'
prefixlen = len(prefix)
def f(name):
if name.startswith(prefix):
return name[prefixlen:]
return name
to_regularize = list(map(f, to_regularize))
to_print = ', '.join(to_regularize)
_log_regularizer(to_print)
return tf.add_n(costs, name=name) return tf.add_n(costs, name=name)
......
...@@ -14,11 +14,11 @@ _CurrentTowerContext = None ...@@ -14,11 +14,11 @@ _CurrentTowerContext = None
class TowerContext(object): class TowerContext(object):
""" A context where the current model is being built in. """ """ A context where the current model is being built in. """
def __init__(self, tower_name, is_training=None, index=0, use_vs=False): def __init__(self, tower_name, is_training, index=0, use_vs=False):
""" """
Args: Args:
tower_name (str): The name scope of the tower. tower_name (str): The name scope of the tower.
is_training (bool): if None, automatically determine from tower_name. is_training (bool):
index (int): index of this tower, only used in training. index (int): index of this tower, only used in training.
use_vs (bool): Open a new variable scope with this name. use_vs (bool): Open a new variable scope with this name.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment