Commit 05bf948f authored by Yuxin Wu

Add "SimpleMovingAverage" callback

parent a0db5536
...@@ -211,6 +211,7 @@ class Callback(object):
     def __str__(self):
         return type(self).__name__

+    # TODO RENAME: same function to be used to get ops as well
     def get_tensors_maybe_in_tower(self, names):
         """
         Get tensors in the graph with the given names.
...
...@@ -4,12 +4,15 @@
 import tensorflow as tf
+import numpy as np
+from collections import deque

+from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .base import Callback

-__all__ = ['MovingAverageSummary', 'MergeAllSummaries']
+__all__ = ['MovingAverageSummary', 'MergeAllSummaries', 'SimpleMovingAverage']


 class MovingAverageSummary(Callback):
...@@ -17,7 +20,8 @@ class MovingAverageSummary(Callback):
     This callback is enabled by default.
     Maintain the moving average of summarized tensors in every step,
     by ops added to the collection.
-    Note that it only maintains the EMAs, the actual summary should be done in other callbacks.
+    Note that it only __maintains__ the moving averages in the graph,
+    the actual summary should be done in other callbacks.
     """

     def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
         """
...@@ -120,3 +124,41 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
         return MergeAllSummaries_RunAlone(period, key)
     else:
         return MergeAllSummaries_RunWithOp(period, key)
+
+
+class SimpleMovingAverage(Callback):
+    """
+    Monitor Simple Moving Average (SMA), i.e. an average within a sliding window,
+    of some tensors.
+    """
+    def __init__(self, tensors, window_size):
+        """
+        Args:
+            tensors (str or [str]): names of tensors
+            window_size (int): size of the moving window
+        """
+        self._tensor_names = [get_op_tensor_name(x)[1] for x in tensors]
+        self._display_names = [get_op_tensor_name(x)[0] for x in tensors]
+        self._window = int(window_size)
+        self._queue = deque(maxlen=window_size)
+
+    def _setup_graph(self):
+        tensors = self.get_tensors_maybe_in_tower(self._tensor_names)
+        for t in tensors:
+            assert t.get_shape().ndims == 0, \
+                "SimpleMovingAverage only accepts scalar tensor! Got one with {}".format(t.get_shape())
+        self._fetch = tf.train.SessionRunArgs(fetches=tensors)
+
+    def _before_run(self, _):
+        return self._fetch
+
+    def _after_run(self, _, rv):
+        results = rv.results
+        self._queue.append(results)
+
+    def _trigger_step(self):
+        if self.global_step % self._window == 0:
+            averages = np.asarray(self._queue).mean(axis=0)
+            for name, avg in zip(self._display_names, averages):
+                self.trainer.monitors.put_scalar(name + '/SMA', avg)
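A minimal usage sketch of the new callback (the tensor names 'total_cost' and 'accuracy' and the window size are made up for illustration; the callback is registered like any other tensorpack callback):

    from tensorpack.callbacks import SimpleMovingAverage

    # average the two named scalar tensors over a sliding window of 100 steps
    # and log them to the monitors as 'total_cost/SMA' and 'accuracy/SMA'
    callbacks = [
        SimpleMovingAverage(tensors=['total_cost', 'accuracy'], window_size=100),
    ]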
...@@ -11,7 +11,6 @@ from contextlib import contextmanager
 from tensorflow.python.training import moving_averages

 from ..utils import logger
-from ..utils.develop import log_deprecated
 from ..utils.argtools import graph_memoized
 from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .tower import get_current_tower_context
...@@ -19,7 +18,8 @@ from .symbolic_functions import rms
 from .scope_utils import cached_name_scope

 __all__ = ['add_tensor_summary', 'add_param_summary',
-           'add_activation_summary', 'add_moving_summary']
+           'add_activation_summary', 'add_moving_summary',
+           ]

 # some scope stuff to use internally...
...@@ -196,6 +196,7 @@ def add_param_summary(*summary_lists, **kwargs):
         add_tensor_summary(p, actions, name=name, collections=collections)


+# TODO: collection for the summary op
 def add_moving_summary(*args, **kwargs):
     """
     Summarize the moving average for scalar tensors.
...@@ -224,24 +225,16 @@ def add_moving_summary(*args, **kwargs):
         logger.warn("add_moving_summary() called under reuse=True scope, ignored.")
         return []

-    if not isinstance(args[0], list):
-        v = args
-    else:
-        log_deprecated("Call add_moving_summary with positional args instead of a list!", eos="2018-02-28")
-        v = args[0]
-    for x in v:
+    for x in args:
         assert isinstance(x, (tf.Tensor, tf.Variable)), x
         assert x.get_shape().ndims == 0, \
             "add_moving_summary() only accepts scalar tensor! Got one with {}".format(x.get_shape())
-    G = tf.get_default_graph()

     # TODO variable not saved under distributed
     ema_ops = []
-    for c in v:
+    for c in args:
         name = re.sub('tower[0-9]+/', '', c.op.name)
-        # TODO colocate may affect distributed setting
-        # colocate variable with compute op implies that the variable should be local_vars
-        with G.colocate_with(c), tf.name_scope(None):
+        with tf.name_scope(None):
             if not c.dtype.is_floating:
                 c = tf.cast(c, tf.float32)
             # assign_moving_average creates variables with op names, therefore clear ns first.
...@@ -255,11 +248,9 @@ def add_moving_summary(*args, **kwargs):
                     zero_debias=True, name=name + '_EMA_apply')
             ema_ops.append(ema_op)
         with tf.name_scope(None):
-            # cannot add it into colocate group -- will force everything to cpus
             tf.summary.scalar(name + '-summary', ema_op)   # write the EMA value as a summary
     if coll is not None:
         for op in ema_ops:
-            # TODO a new collection to summary every step?
             tf.add_to_collection(coll, op)
     return ema_ops
...
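With this change, add_moving_summary() no longer accepts a single list argument; scalar tensors are passed positionally. A minimal sketch of the calling convention (the tensors below are synthetic placeholders, not part of the commit):

    import tensorflow as tf
    from tensorpack.tfutils.summary import add_moving_summary

    # two scalar tensors; an EMA variable is maintained for each and
    # written out as '<name>-summary'
    cost = tf.reduce_mean(tf.random_uniform([8]), name='cost')
    accuracy = tf.reduce_mean(tf.random_uniform([8]), name='accuracy')

    add_moving_summary(cost, accuracy)      # positional args
    # add_moving_summary([cost, accuracy])  # old list form, removed here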