Commit 05bf948f authored by Yuxin Wu

Add "SimpleMovingAverage" callback

parent a0db5536
......@@ -211,6 +211,7 @@ class Callback(object):
def __str__(self):
return type(self).__name__
# TODO RENAME: same function to be used to get ops as well
def get_tensors_maybe_in_tower(self, names):
"""
Get tensors in the graph with the given names.
......
......@@ -4,12 +4,15 @@
import tensorflow as tf
import numpy as np
from collections import deque
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .base import Callback
__all__ = ['MovingAverageSummary', 'MergeAllSummaries']
__all__ = ['MovingAverageSummary', 'MergeAllSummaries', 'SimpleMovingAverage']
class MovingAverageSummary(Callback):
......@@ -17,7 +20,8 @@ class MovingAverageSummary(Callback):
This callback is enabled by default.
Maintain the moving average of summarized tensors in every step,
by ops added to the collection.
Note that it only maintains the EMAs, the actual summary should be done in other callbacks.
Note that it only __maintains__ the moving averages in the graph;
the actual summary should be done in other callbacks.
"""
def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
"""
......@@ -120,3 +124,41 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
return MergeAllSummaries_RunAlone(period, key)
else:
return MergeAllSummaries_RunWithOp(period, key)
class SimpleMovingAverage(Callback):
"""
Monitor the Simple Moving Average (SMA) of scalar tensors,
i.e. the average within a sliding window.
"""
def __init__(self, tensors, window_size):
"""
Args:
tensors (str or [str]): names of tensors
window_size (int): size of the moving window
"""
self._tensor_names = [get_op_tensor_name(x)[1] for x in tensors]
self._display_names = [get_op_tensor_name(x)[0] for x in tensors]
self._window = int(window_size)
self._queue = deque(maxlen=window_size)
def _setup_graph(self):
tensors = self.get_tensors_maybe_in_tower(self._tensor_names)
for t in tensors:
assert t.get_shape().ndims == 0, \
"SimpleMovingAverage only accepts scalar tensor! Got one with {}".format(t.get_shape())
self._fetch = tf.train.SessionRunArgs(fetches=tensors)
def _before_run(self, _):
return self._fetch
def _after_run(self, _, rv):
results = rv.results
self._queue.append(results)
def _trigger_step(self):
if self.global_step % self._window == 0:
averages = np.asarray(self._queue).mean(axis=0)
for name, avg in zip(self._display_names, averages):
self.trainer.monitors.put_scalar(name + '/SMA', avg)
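A minimal usage sketch (not part of this commit), assuming a standard tensorpack callbacks list; the tensor name "total_cost" is hypothetical:

from tensorpack.callbacks import SimpleMovingAverage

callbacks = [
    # Fetches the scalar tensor "total_cost" every step and, every 100 steps,
    # logs its sliding-window average to the monitors as "total_cost/SMA".
    SimpleMovingAverage(tensors=['total_cost'], window_size=100),
]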
......@@ -11,7 +11,6 @@ from contextlib import contextmanager
from tensorflow.python.training import moving_averages
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.argtools import graph_memoized
from ..utils.naming import MOVING_SUMMARY_OPS_KEY
from .tower import get_current_tower_context
......@@ -19,7 +18,8 @@ from .symbolic_functions import rms
from .scope_utils import cached_name_scope
__all__ = ['add_tensor_summary', 'add_param_summary',
'add_activation_summary', 'add_moving_summary']
'add_activation_summary', 'add_moving_summary',
]
# some scope stuff to use internally...
......@@ -196,6 +196,7 @@ def add_param_summary(*summary_lists, **kwargs):
add_tensor_summary(p, actions, name=name, collections=collections)
# TODO: collection for the summary op
def add_moving_summary(*args, **kwargs):
"""
Summarize the moving average for scalar tensors.
......@@ -224,24 +225,16 @@ def add_moving_summary(*args, **kwargs):
logger.warn("add_moving_summary() called under reuse=True scope, ignored.")
return []
if not isinstance(args[0], list):
v = args
else:
log_deprecated("Call add_moving_summary with positional args instead of a list!", eos="2018-02-28")
v = args[0]
for x in v:
for x in args:
assert isinstance(x, (tf.Tensor, tf.Variable)), x
assert x.get_shape().ndims == 0, \
"add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())
G = tf.get_default_graph()
# TODO variable not saved under distributed
ema_ops = []
for c in v:
for c in args:
name = re.sub('tower[0-9]+/', '', c.op.name)
# TODO colocate may affect distributed setting
# colocate variable with compute op implies that the variable should be local_vars
with G.colocate_with(c), tf.name_scope(None):
with tf.name_scope(None):
if not c.dtype.is_floating:
c = tf.cast(c, tf.float32)
# assign_moving_average creates variables with op names, therefore clear ns first.
......@@ -255,11 +248,9 @@ def add_moving_summary(*args, **kwargs):
zero_debias=True, name=name + '_EMA_apply')
ema_ops.append(ema_op)
with tf.name_scope(None):
# cannot add it into colocate group -- will force everything to cpus
tf.summary.scalar(name + '-summary', ema_op) # write the EMA value as a summary
if coll is not None:
for op in ema_ops:
# TODO a new collection to summary every step?
tf.add_to_collection(coll, op)
return ema_ops
......
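For reference, a small sketch (not part of the commit) of the call convention after this change: add_moving_summary now takes scalar tensors as positional arguments, and the single-list form removed above is no longer accepted. "cost" and "accuracy" below are hypothetical scalar tensors.

from tensorpack.tfutils.summary import add_moving_summary

# "cost" and "accuracy" are assumed to be scalar tf.Tensor objects.
ema_ops = add_moving_summary(cost, accuracy)   # positional scalars: OK
# add_moving_summary([cost, accuracy])         # list form: removed by this commit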