Commit 4f4b1eee authored by Yuxin Wu's avatar Yuxin Wu

add collections option to internal summaries

parent ffd78d7e
......@@ -69,6 +69,9 @@ class ModelSaver(Callback):
keep_checkpoint_every_n_hours=self._keep_every_n_hours,
write_version=tf.train.SaverDef.V2,
save_relative_paths=True)
# Don't know how it can be useful,
# but since there is a predefined key, why not use it?
tf.add_to_collection(tf.GraphKeys.SAVERS, self.saver)
def _before_train(self):
# graph is finalized, OK to write it now.
......
......@@ -18,8 +18,8 @@ from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
from .symbolic_functions import rms
__all__ = ['create_scalar_summary', 'add_param_summary',
'add_activation_summary', 'add_moving_summary']
__all__ = ['create_scalar_summary', 'create_image_summary',
'add_param_summary', 'add_activation_summary', 'add_moving_summary']
def create_scalar_summary(name, v):
......@@ -71,12 +71,14 @@ def create_image_summary(name, val):
return s
def add_activation_summary(x, name=None):
def add_activation_summary(x, name=None, collections=None):
"""
Add summary for an activation tensor x. If name is None, use x.name.
Add summary for an activation tensor x, including
its sparsity, rms, and histogram.
Args:
x (tf.Tensor): the tensor to summary.
name (str): if is None, use x.name.
"""
ctx = get_current_tower_context()
if ctx is not None and not ctx.is_main_training_tower:
......@@ -88,12 +90,14 @@ def add_activation_summary(x, name=None):
if name is None:
name = x.name
with tf.name_scope('activation-summary'):
tf.summary.histogram(name, x)
tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(x))
tf.summary.scalar(name + '-rms', rms(x))
tf.summary.histogram(name, x, collections=collections)
tf.summary.scalar(
name + '-sparsity', tf.nn.zero_fraction(x),
collections=collections)
tf.summary.scalar(name + '-rms', rms(x), collections=collections)
def add_param_summary(*summary_lists):
def add_param_summary(*summary_lists, collections=None):
"""
Add summary Ops for all trainable variables matching the regex.
......@@ -113,20 +117,24 @@ def add_param_summary(*summary_lists):
name = var.name.replace(':0', '')
if action == 'scalar':
assert ndim == 0, "Scalar summary on high-dimension data. Maybe you want 'mean'?"
tf.summary.scalar(name, var)
tf.summary.scalar(name, var, collections=collections)
return
assert ndim > 0, "Cannot perform {} summary on scalar data".format(action)
if action == 'histogram':
tf.summary.histogram(name, var)
tf.summary.histogram(name, var, collections=collections)
return
if action == 'sparsity':
tf.summary.scalar(name + '-sparsity', tf.nn.zero_fraction(var))
tf.summary.scalar(
name + '-sparsity', tf.nn.zero_fraction(var),
collections=collections)
return
if action == 'mean':
tf.summary.scalar(name + '-mean', tf.reduce_mean(var))
tf.summary.scalar(
name + '-mean', tf.reduce_mean(var),
collections=collections)
return
if action == 'rms':
tf.summary.scalar(name + '-rms', rms(var))
tf.summary.scalar(name + '-rms', rms(var), collections=collections)
return
raise RuntimeError("Unknown summary type: {}".format(action))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment