Commit bbe8f42a authored by Yuxin Wu

remove test callback. add .print in linearwrap

parent 7a110067
@@ -102,7 +102,6 @@ class Model(ModelDesc):
         logits = (LinearWrap(image)
                   .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
                   .apply(activate)
                   .Conv2D('conv1', 256, 5, padding='SAME', split=2)
                   .apply(fg)
                   .BatchNorm('bn1')
...
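For context, LinearWrap wraps a tensor so that tensorpack's registered layer functions can be called as chainable methods, each call returning a new wrapper. A minimal sketch of the idea, not tensorpack's actual implementation (the LAYER_REGISTRY dict is a hypothetical stand-in for however layer names are actually resolved):

# Hypothetical registry of layer functions; this dict only exists to
# make the sketch self-contained.
LAYER_REGISTRY = {}

class LinearWrapSketch(object):
    def __init__(self, tensor):
        self._t = tensor

    def __getattr__(self, layer_name):
        layer_func = LAYER_REGISTRY[layer_name]
        def call(name, *args, **kwargs):
            # Tensorpack layers take (name, input, ...); feed the current
            # tensor as the input and re-wrap the output for chaining.
            return LinearWrapSketch(layer_func(name, self._t, *args, **kwargs))
        return call

    def apply(self, func, *args, **kwargs):
        # Run an arbitrary function (e.g. activate/fg above) on the tensor.
        return LinearWrapSketch(func(self._t, *args, **kwargs))

    def tensor(self):
        return self._t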
@@ -10,24 +10,12 @@ from abc import abstractmethod, ABCMeta
 from ..utils import *

-__all__ = ['Callback', 'PeriodicCallback', 'TrainCallbackType', 'TestCallbackType']
+__all__ = ['Callback', 'PeriodicCallback']

-class TrainCallbackType(object):
-    pass
-
-class TestCallbackType(object):
-    pass
-
 class Callback(object):
     """ Base class for all callbacks """
     __metaclass__ = ABCMeta

-    type = TrainCallbackType()
-    """ Determine the graph that this callback should run on.
-        Either `TrainCallbackType()` or `TestCallbackType()`.
-        Default is `TrainCallbackType()`
-    """

     def before_train(self):
         """
         Called right before the first iteration.
...
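With the type attribute gone, every callback runs in the training graph and a subclass needs no graph declaration. A sketch of a minimal callback under the new base class (the name EvalCallback and the method bodies are illustrative; the underscore-prefixed override convention mirrors the Callbacks class later in this commit):

class EvalCallback(Callback):
    def _before_train(self):
        # called right before the first iteration
        print('starting training')

    def _trigger_epoch(self):
        # called after every epoch; now always in the training graph
        print('epoch done')

Previously a callback meant for the test graph would have set type = TestCallbackType(); that whole mechanism is deleted here.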
@@ -6,7 +6,7 @@ import tensorflow as tf
 from contextlib import contextmanager
 import time

-from .base import Callback, TrainCallbackType, TestCallbackType
+from .base import Callback
 from .stat import *
 from ..utils import *

@@ -50,9 +50,6 @@ class Callbacks(Callback):
         # check type
         for cb in cbs:
             assert isinstance(cb, Callback), cb.__class__
-            if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)):
-                raise ValueError(
-                    "Unknown callback running graph {}!".format(str(cb.type)))
         # move "StatPrinter" to the last
         for cb in cbs:
             if isinstance(cb, StatPrinter):
@@ -62,24 +59,15 @@ class Callbacks(Callback):
                 break
         self.cbs = cbs
-        self.test_callback_context = TestCallbackContext()

     def _setup_graph(self):
         with tf.name_scope(None):
             for cb in self.cbs:
-                if isinstance(cb.type, TrainCallbackType):
-                    cb.setup_graph(self.trainer)
-                else:
-                    with self.test_callback_context.create_context(self.trainer):
-                        cb.setup_graph(self.trainer)
+                cb.setup_graph(self.trainer)

     def _before_train(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallbackType):
-                cb.before_train()
-            else:
-                with self.test_callback_context.test_context():
-                    cb.before_train()
+            cb.before_train()

     def _after_train(self):
         for cb in self.cbs:

@@ -87,9 +75,7 @@ class Callbacks(Callback):
     def trigger_step(self):
         for cb in self.cbs:
-            if isinstance(cb.type, TrainCallbackType):
-                cb.trigger_step()
-            # test callback don't have trigger_step
+            cb.trigger_step()

     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
@@ -97,15 +83,6 @@ class Callbacks(Callback):
         test_sess_restored = False
         for cb in self.cbs:
             display_name = str(cb)
-            if isinstance(cb.type, TrainCallbackType):
-                with tm.timed_callback(display_name):
-                    cb.trigger_epoch()
-            else:
-                if not test_sess_restored:
-                    with tm.timed_callback('restore checkpoint'):
-                        self.test_callback_context.restore_checkpoint()
-                    test_sess_restored = True
-                with self.test_callback_context.test_context(), \
-                        tm.timed_callback(display_name):
-                    cb.trigger_epoch()
+            with tm.timed_callback(display_name):
+                cb.trigger_epoch()
         tm.log()
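The timing plumbing is untouched: CallbackTimeLogger comes from .stat and is not part of this diff. A minimal sketch of the interface _trigger_epoch relies on, assuming roughly this shape (the real class may record and report more):

from contextlib import contextmanager
import time

class CallbackTimeLoggerSketch(object):
    def __init__(self):
        self.times = []

    @contextmanager
    def timed_callback(self, name):
        # time the with-block body, remember it under the callback's name
        start = time.time()
        yield
        self.times.append((name, time.time() - start))

    def log(self):
        # print one line per callback with the seconds it took this epoch
        for name, t in self.times:
            print('{}: {:.3f} sec'.format(name, t))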
@@ -70,4 +70,8 @@ class LinearWrap(object):
     def tensor(self):
         return self._t
+
+    def print(self):
+        print(self._t)
+        return self
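Because the new print() returns self, it can be dropped into the middle of a chain to inspect an intermediate tensor without breaking the chain, e.g. (layer arguments borrowed from the example model above):

logits = (LinearWrap(image)
          .Conv2D('conv0', 96, 12, stride=4, padding='VALID')
          .print()  # prints the conv0 output tensor, chain continues
          .apply(activate)
          .tensor())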
@@ -55,7 +55,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         batch_var = tf.identity(batch_var, 'variance')
         emaname = 'EMA'
-        ctx = get_current_model_context()
+        ctx = get_current_tower_context()
         if use_local_stat is None:
             use_local_stat = ctx.is_training
         assert use_local_stat == ctx.is_training

@@ -73,17 +73,23 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     else:
         assert not use_local_stat
         with tf.name_scope(None):
-            # figure out the var name
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            mean_var_name = ema.average_name(batch_mean) + ':0'
-            var_var_name = ema.average_name(batch_var) + ':0'
-            # use statistics in another tower
-            G = tf.get_default_graph()
-            ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
-            ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
-            #logger.info("In prediction, using {} instead of {} for {}".format(
-                #mean_name, ema_mean.name, batch_mean.name))
+            if ctx.is_main_tower:
+                # not training, but main tower. need to create the vars
+                with tf.name_scope(None):
+                    ema_apply_op = ema.apply([batch_mean, batch_var])
+                ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+            else:
+                # use statistics in another tower
+                G = tf.get_default_graph()
+                # figure out the var name
+                mean_var_name = ema.average_name(batch_mean) + ':0'
+                var_var_name = ema.average_name(batch_var) + ':0'
+                ema_mean = ctx.find_tensor_in_main_tower(G, mean_var_name)
+                ema_var = ctx.find_tensor_in_main_tower(G, var_var_name)
+                #logger.info("In prediction, using {} instead of {} for {}".format(
+                    #mean_name, ema_mean.name, batch_mean.name))
     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):
...
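The else branch works because tf.train.ExponentialMovingAverage.average_name(v) returns the name of the shadow variable that ema.apply([v]) would create, and appending ':0' turns that variable name into its output tensor name, which can then be fetched from the graph. find_tensor_in_main_tower belongs to the tower context and is not shown in this diff; a hedged sketch of what it presumably does, assuming prediction towers live under a name-scope prefix (the 'towerp0/' prefix here is hypothetical):

def find_tensor_in_main_tower(graph, name, tower_prefix='towerp0/'):
    # If the requested name carries this tower's scope prefix, strip it
    # so the lookup hits the main tower's copy of the EMA tensor.
    if name.startswith(tower_prefix):
        name = name[len(tower_prefix):]
    return graph.get_tensor_by_name(name)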