Commit bbf41d9e authored by Yuxin Wu

use positional args instead of list for add_param_summary

parent 8f797c63
......@@ -134,8 +134,8 @@ class Model(ModelDesc):
self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
tf.cast(BATCH_SIZE, tf.float32), name='cost')
summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms'])]) # monitor all W
summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
('fc.*/W', ['histogram', 'rms'])) # monitor all W
add_moving_summary(self.cost)
def update_target_param(self):
......
......@@ -157,7 +157,7 @@ class Model(ModelDesc):
# weight decay on all W of fc layers
wd_cost = regularize_cost('fc.*/W', l2_regularizer(5e-6))
add_param_summary([('.*/W', ['histogram', 'rms'])])
add_param_summary(('.*/W', ['histogram', 'rms']))
self.cost = tf.add_n([cost, wd_cost], name='cost')
add_moving_summary(cost, wd_cost, self.cost)
......
......@@ -122,7 +122,7 @@ class Model(ModelDesc):
# weight decay on all W of fc layers
wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))
add_param_summary([('.*/W', ['histogram', 'rms'])])
add_param_summary(('.*/W', ['histogram', 'rms']))
self.cost = tf.add_n([cost, wd_cost], name='cost')
add_moving_summary(cost, wd_cost, self.cost)
......
......@@ -89,7 +89,7 @@ class Model(ModelDesc):
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
costs.append(wd_cost)
add_param_summary([('.*/W', ['histogram'])]) # monitor W
add_param_summary(('.*/W', ['histogram'])) # monitor W
self.cost = tf.add_n(costs, name='cost')
add_moving_summary(costs + [wrong, self.cost])
......
......@@ -115,7 +115,7 @@ class Model(ModelDesc):
80000, 0.7, True)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='l2_regularize_loss')
add_param_summary([('.*/W', ['histogram'])]) # monitor W
add_param_summary(('.*/W', ['histogram'])) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost')
add_moving_summary(wd_cost, self.cost)
......
......@@ -103,7 +103,7 @@ class Model(ModelDesc):
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
add_moving_summary(cost, wd_cost)
add_param_summary([('.*/W', ['histogram'])]) # monitor W
add_param_summary(('.*/W', ['histogram'])) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost')
......
......@@ -91,7 +91,7 @@ class Model(ModelDesc):
xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, symbolic_functions.flatten(nextinput))
self.cost = tf.reduce_mean(xent_loss, name='cost')
summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
summary.add_param_summary(('.*/W', ['histogram'])) # monitor histogram of all W
summary.add_moving_summary(self.cost)
def get_gradient_processor(self):
......
......@@ -71,7 +71,7 @@ class Model(ModelDesc):
name='regularize_loss')
add_moving_summary(cost, wd_cost)
add_param_summary([('.*/W', ['histogram'])]) # monitor W
add_param_summary(('.*/W', ['histogram'])) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost')
......
......@@ -107,9 +107,9 @@ class Model(ModelDesc):
summary.add_moving_summary(cost)
# monitor histogram of all weight (of conv and fc layers) in tensorboard
summary.add_param_summary([('.*/W', ['histogram', 'rms']),
('.*/weights', ['histogram', 'rms']) # to also work with slim
])
summary.add_param_summary(('.*/W', ['histogram', 'rms']),
('.*/weights', ['histogram', 'rms']) # to also work with slim
)
def get_data():
......
......@@ -57,7 +57,7 @@ class Model(ModelDesc):
wd_cost = regularize_cost('fc.*/W', l2_regularizer(0.00001))
add_moving_summary(cost, wd_cost)
add_param_summary([('.*/W', ['histogram', 'rms'])]) # monitor W
add_param_summary(('.*/W', ['histogram', 'rms'])) # monitor W
self.cost = tf.add_n([cost, wd_cost], name='cost')
......
......@@ -11,7 +11,7 @@ __all__ = ['describe_model', 'get_shape_str']
def describe_model():
""" print a description of the current model parameters """
""" Print a description of the current model parameters """
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
msg = [""]
total = 0
......@@ -29,8 +29,10 @@ def describe_model():
def get_shape_str(tensors):
"""
:param tensors: a tensor or a list of tensors
:returns: a string to describe the shape
Args:
tensors (list or tf.Tensor): a tensor or a list of tensors
Returns:
str: a string to describe the shape
"""
if isinstance(tensors, (list, tuple)):
for v in tensors:
......
......@@ -7,6 +7,7 @@ import tensorflow as tf
import re
from ..utils.argtools import memoized
from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_VARS_KEY
from .tower import get_current_tower_context
from . import get_global_step_var
......@@ -18,7 +19,8 @@ __all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
def create_summary(name, v):
"""
Return a tf.Summary object with name and simple scalar value v
Returns:
tf.Summary: a tf.Summary object with name and simple scalar value v.
"""
assert isinstance(name, six.string_types), type(name)
v = float(v)
......@@ -29,8 +31,10 @@ def create_summary(name, v):
def add_activation_summary(x, name=None):
"""
Add summary to graph for an activation tensor x.
If name is None, use x.name.
Add summary for an activation tensor x. If name is None, use x.name.
Args:
x (tf.Tensor): the tensor to summary.
"""
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
......@@ -47,16 +51,20 @@ def add_activation_summary(x, name=None):
tf.summary.scalar(name + '-rms', rms(x))
def add_param_summary(summary_lists):
def add_param_summary(*summary_lists):
"""
Add summary for all trainable variables matching the regex
Add summary Ops for all trainable variables matching the regex.
:param summary_lists: list of (regex, [list of summary type to perform]).
Type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
Args:
summary_lists (list): each is (regex, [list of summary type to perform]).
Summary type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
"""
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
return
if len(summary_lists) == 1 and isinstance(summary_lists[0], list):
logger.warn("[Deprecated] Use positional args to call add_param_summary() instead of a list.")
summary_lists = summary_lists[0]
def perform(var, action):
ndim = var.get_shape().ndims
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment