Commit dc31efa4 authored by Yuxin Wu

Change trigger_epoch to trigger in CallbackFactory

parent c223f223
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import tensorflow as tf import tensorflow as tf
from abc import ABCMeta from abc import ABCMeta
import six import six
from ..utils.develop import log_deprecated
from ..tfutils.common import get_op_or_tensor_by_name from ..tfutils.common import get_op_or_tensor_by_name
__all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable'] __all__ = ['Callback', 'ProxyCallback', 'CallbackFactory', 'Triggerable']
...@@ -255,17 +256,23 @@ class CallbackFactory(Callback): ...@@ -255,17 +256,23 @@ class CallbackFactory(Callback):
""" """
Create a callback with some lambdas. Create a callback with some lambdas.
""" """
def __init__(self, setup_graph=None, before_train=None, def __init__(self, setup_graph=None, before_train=None, trigger=None,
trigger_epoch=None, after_train=None): after_train=None, trigger_epoch=None):
""" """
Each lambda takes ``self`` as the only argument. Each lambda takes ``self`` as the only argument.
trigger_epoch was deprecated.
""" """
self._cb_setup_graph = setup_graph self._cb_setup_graph = setup_graph
self._cb_before_train = before_train self._cb_before_train = before_train
self._cb_trigger_epoch = trigger_epoch self._cb_trigger = trigger
self._cb_after_train = after_train self._cb_after_train = after_train
if trigger_epoch:
self._cb_trigger = trigger_epoch
log_deprecated("CallbackFactory(trigger_epoch=)", "Use trigger instead.", "2017-11-15")
def _setup_graph(self): def _setup_graph(self):
if self._cb_setup_graph: if self._cb_setup_graph:
self._cb_setup_graph(self) self._cb_setup_graph(self)
...@@ -274,9 +281,9 @@ class CallbackFactory(Callback): ...@@ -274,9 +281,9 @@ class CallbackFactory(Callback):
if self._cb_before_train: if self._cb_before_train:
self._cb_before_train(self) self._cb_before_train(self)
def _trigger_epoch(self): def _trigger(self):
if self._cb_trigger_epoch: if self._cb_trigger:
self._cb_trigger_epoch(self) self._cb_trigger(self)
def _after_train(self): def _after_train(self):
if self._cb_after_train: if self._cb_after_train:
......
...@@ -98,7 +98,3 @@ class Callbacks(Callback): ...@@ -98,7 +98,3 @@ class Callbacks(Callback):
def _after_epoch(self): def _after_epoch(self):
for cb in self.cbs: for cb in self.cbs:
cb.after_epoch() cb.after_epoch()
def append(self, cb):
assert isinstance(cb, Callback)
self.cbs.append(cb)
...@@ -163,13 +163,14 @@ def add_moving_summary(v, *args, **kwargs): ...@@ -163,13 +163,14 @@ def add_moving_summary(v, *args, **kwargs):
for x in v: for x in v:
assert isinstance(x, tf.Tensor), x assert isinstance(x, tf.Tensor), x
assert x.get_shape().ndims == 0, x.get_shape() assert x.get_shape().ndims == 0, x.get_shape()
gs = get_global_step_var()
# TODO will produce variable tower0/xxx? # TODO will produce variable tower0/xxx?
# TODO not saved under distributed # TODO not saved under distributed
# TODO use zero_debias # TODO use zero_debias
# TODO create EMA for each variable separately, so that the maintain ops # TODO create EMA for each variable separately, so that the maintain ops
# have a decent name (rather than EMA) # have a decent name (rather than EMA)
gs = get_global_step_var() # clear namescope, otherwise the variable names will have duplicated name scope
with tf.device(gs.device): with tf.name_scope(None), tf.device(gs.device):
averager = tf.train.ExponentialMovingAverage( averager = tf.train.ExponentialMovingAverage(
decay, num_updates=gs, name='EMA') decay, num_updates=gs, name='EMA')
avg_maintain_op = averager.apply(v) avg_maintain_op = averager.apply(v)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment