Commit dc31efa4 authored by Yuxin Wu

Change trigger_epoch to trigger in CallbackFactory

parent c223f223
......@@ -5,6 +5,7 @@
 import tensorflow as tf
 from abc import ABCMeta
 import six
+from ..utils.develop import log_deprecated
 from ..tfutils.common import get_op_or_tensor_by_name

 __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
......@@ -255,17 +256,23 @@ class CallbackFactory(Callback):
     """
     Create a callback with some lambdas.
     """
-    def __init__(self, setup_graph=None, before_train=None,
-                 trigger_epoch=None, after_train=None):
+    def __init__(self, setup_graph=None, before_train=None, trigger=None,
+                 after_train=None, trigger_epoch=None):
         """
         Each lambda takes ``self`` as the only argument.
+        trigger_epoch is deprecated.
         """
         self._cb_setup_graph = setup_graph
         self._cb_before_train = before_train
-        self._cb_trigger_epoch = trigger_epoch
+        self._cb_trigger = trigger
         self._cb_after_train = after_train
+
+        if trigger_epoch:
+            self._cb_trigger = trigger_epoch
+            log_deprecated("CallbackFactory(trigger_epoch=)", "Use trigger instead.", "2017-11-15")

     def _setup_graph(self):
         if self._cb_setup_graph:
             self._cb_setup_graph(self)
......@@ -274,9 +281,9 @@ class CallbackFactory(Callback):
         if self._cb_before_train:
             self._cb_before_train(self)

-    def _trigger_epoch(self):
-        if self._cb_trigger_epoch:
-            self._cb_trigger_epoch(self)
+    def _trigger(self):
+        if self._cb_trigger:
+            self._cb_trigger(self)

     def _after_train(self):
         if self._cb_after_train:
......
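For context, a hedged usage sketch of the renamed interface. The import path and lambda bodies are illustrative assumptions, not part of the commit:

# Sketch only: assumes CallbackFactory is importable from tensorpack.callbacks.
from tensorpack.callbacks import CallbackFactory

cb = CallbackFactory(
    before_train=lambda self: print("before training"),
    trigger=lambda self: print("triggered"),    # new keyword argument
    after_train=lambda self: print("after training"))

# The deprecated keyword still works: per the diff above, it is copied onto
# the trigger slot and logs a warning until the stated EOL date (2017-11-15).
cb_old = CallbackFactory(trigger_epoch=lambda self: print("triggered"))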
......@@ -98,7 +98,3 @@ class Callbacks(Callback):
     def _after_epoch(self):
         for cb in self.cbs:
             cb.after_epoch()
-
-    def append(self, cb):
-        assert isinstance(cb, Callback)
-        self.cbs.append(cb)
......@@ -163,13 +163,14 @@ def add_moving_summary(v, *args, **kwargs):
     for x in v:
         assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-    gs = get_global_step_var()
     # TODO will produce variable tower0/xxx?
     # TODO not saved under distributed
     # TODO use zero_debias
+    # TODO create EMA for each variable separately, so that the maintain ops
+    # have a decent name (rather than EMA)
+    gs = get_global_step_var()
-    with tf.device(gs.device):
+    # clear namescope, otherwise the variable names will have duplicated name scope
+    with tf.name_scope(None), tf.device(gs.device):
         averager = tf.train.ExponentialMovingAverage(
             decay, num_updates=gs, name='EMA')
         avg_maintain_op = averager.apply(v)
......
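The tf.name_scope(None) addition is the key change here: passing None resets the current name scope, so the EMA ops are created at the graph root rather than under the caller's scope. A minimal sketch of the mechanism (TF 1.x graph mode assumed; op names are illustrative):

import tensorflow as tf  # assumes the TF 1.x graph-mode API

with tf.name_scope("tower0"):
    a = tf.add(1, 2)             # op named "tower0/Add"
    with tf.name_scope(None):    # None resets to the top-level name scope
        b = tf.add(1, 2)         # op named "Add", free of the "tower0" prefix

print(a.op.name)  # tower0/Add
print(b.op.name)  # Add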