Commit 884af444 authored by Yuxin Wu

better train/test callback management

parent d731cf7b
@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta
 from ..utils import *

-__all__ = ['Callback', 'PeriodicCallback']
+__all__ = ['Callback', 'PeriodicCallback', 'TrainCallback', 'TestCallback']

+class TrainCallback(object):
+    pass
+
+class TestCallback(object):
+    pass
+
 class Callback(object):
     __metaclass__ = ABCMeta
-    running_graph = 'train'
+    type = TrainCallback()
     """ The graph that this callback should run on.
-        Either 'train' or 'test'
+        Either TrainCallback or TestCallback
     """

     def before_train(self):
...
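For orientation: the string-valued `running_graph` attribute is replaced by a class-instance marker, so a callback now declares where it runs by assigning `type`. A minimal sketch under the definitions above (the callback class and its body are hypothetical, not part of the commit):

class MyValidation(Callback):
    # opt into the testing graph/session instead of the training one
    type = TestCallback()

    def _trigger_epoch(self):
        pass  # evaluate on the test graph here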
@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback):
                 global_step=self.global_step)

 class MinSaver(Callback):
-    pass
+    def __init__(self, monitor_stat):
+        self.monitor_stat = monitor_stat
+
+    def _trigger_epoch(self):
+        pass
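`MinSaver` is still a skeleton at this point; the `monitor_stat` argument suggests it will eventually keep the checkpoint with the lowest value of a monitored statistic. A hypothetical completion under that assumption (the `get_stat_now` accessor is assumed, not part of this commit):

class MinSaver(Callback):
    def __init__(self, monitor_stat):
        self.monitor_stat = monitor_stat
        self.min_value = None

    def _trigger_epoch(self):
        # assumed accessor: fetch the newest value of the monitored statistic
        value = logger.stat_holder.get_stat_now(self.monitor_stat)
        if self.min_value is None or value < self.min_value:
            self.min_value = value
            # a real implementation would snapshot the current model here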
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from contextlib import contextmanager

-from .base import Callback
+from .base import Callback, TrainCallback, TestCallback
 from .summary import *
 from ..utils import *

@@ -41,6 +41,12 @@ class CallbackTimeLogger(object):
         self.tot += time
         self.times.append((name, time))

+    @contextmanager
+    def timed_callback(self, name):
+        s = time.time()
+        yield
+        self.add(name, time.time() - s)
+
     def log(self):
         """ log the time of some heavy callbacks """
         if self.tot < 3:
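The new `timed_callback` context manager factors out the start/stop timing that was previously written by hand around each callback. Illustrative use, mirroring how the hunk below applies it (`cb` stands for any callback instance):

tm = CallbackTimeLogger()
with tm.timed_callback(type(cb).__name__):
    cb.trigger_epoch()  # elapsed time is recorded under the callback's name
tm.log()  # prints a summary only when the total time is significant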
@@ -53,119 +59,98 @@ class CallbackTimeLogger(object):
             "Callbacks took {:.3f} sec in total. {}".format(
                 self.tot, ' '.join(msgs)))

-class TrainCallbacks(Callback):
-    def __init__(self, callbacks):
-        self.cbs = callbacks
-        for idx, cb in enumerate(self.cbs):
-            # put SummaryWriter to the beginning
-            if type(cb) == SummaryWriter:
-                self.cbs.insert(0, self.cbs.pop(idx))
-                break
-        else:
-            logger.warn("SummaryWriter must be used! Insert a default one automatically.")
-            self.cbs.insert(0, SummaryWriter())
-
-    def _before_train(self):
-        for cb in self.cbs:
-            cb.before_train()
-
-    def _after_train(self):
-        for cb in self.cbs:
-            cb.after_train()
-
-    def trigger_step(self):
-        for cb in self.cbs:
-            cb.trigger_step()
-
-    def _trigger_epoch(self):
-        tm = CallbackTimeLogger()
-        for cb in self.cbs:
-            s = time.time()
-            cb.trigger_epoch()
-            tm.add(type(cb).__name__, time.time() - s)
-        tm.log()
-
-class TestCallbacks(Callback):
+class TestCallbackContext(object):
     """
-    Hold callbacks to be run in testing graph.
-    Will set a context with testing graph and testing session, for
-    each test-time callback to run
+    A class holding the context needed for running TestCallback
     """
-    def __init__(self, callbacks):
-        self.cbs = callbacks
+    def __init__(self):
+        self.sess = None

-    def _before_train(self):
+    def _init_test_sess(self):
         with create_test_session() as sess:
             self.sess = sess
             self.graph = sess.graph
             self.saver = tf.train.Saver()
-            for cb in self.cbs:
-                cb.before_train()
-
-    def _after_train(self):
-        for cb in self.cbs:
-            cb.after_train()

+    @contextmanager
+    def before_train_context(self):
+        if self.sess is None:
+            self._init_test_sess()
+        with self.graph.as_default(), self.sess.as_default():
+            yield

-    def _trigger_epoch(self):
-        if not self.cbs:
-            return
-        tm = CallbackTimeLogger()
+    # TODO also do this for after_train?
+    def restore_checkpoint(self):
+        ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+        if ckpt is None:
+            raise RuntimeError(
+                "Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?")
+        logger.info(
+            "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
+        self.saver.restore(self.sess, ckpt.model_checkpoint_path)

+    @contextmanager
+    def trigger_epoch_context(self):
         with self.graph.as_default(), self.sess.as_default():
-            s = time.time()
-            ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-            if ckpt is None:
-                logger.error(
-                    "Cannot find a checkpoint state. Do you forget to use PeriodicSaver?")
-                return
-            logger.info(
-                "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
-            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-            tm.add('restore session', time.time() - s)
-            for cb in self.cbs:
-                s = time.time()
-                cb.trigger_epoch()
-                tm.add(type(cb).__name__, time.time() - s)
-        tm.log()
+            yield

 class Callbacks(Callback):
     def __init__(self, cbs):
-        train_cbs = []
-        test_cbs = []
         # check type
         for cb in cbs:
             assert isinstance(cb, Callback), cb.__class__
-            if cb.running_graph == 'test':
-                test_cbs.append(cb)
-            elif cb.running_graph == 'train':
-                train_cbs.append(cb)
-            else:
+            if not isinstance(cb.type, (TrainCallback, TestCallback)):
                 raise ValueError(
-                    "Unknown callback running graph {}!".format(cb.running_graph))
-        self.train = TrainCallbacks(train_cbs)
-        if test_cbs:
-            self.test = TestCallbacks(test_cbs)
+                    "Unknown callback running graph {}!".format(str(cb.type)))
+        # ensure a SummaryWriter
+        for idx, cb in enumerate(cbs):
+            if type(cb) == SummaryWriter:
+                cbs.insert(0, cbs.pop(idx))
+                break
         else:
-            self.test = None
+            logger.warn("SummaryWriter must be used! Insert a default one automatically.")
+            cbs.insert(0, SummaryWriter())
+        self.cbs = cbs
+        self.test_callback_context = TestCallbackContext()

     def _before_train(self):
-        self.train.before_train()
-        if self.test:
-            self.test.before_train()
+        for cb in self.cbs:
+            if isinstance(cb.type, TrainCallback):
+                cb.before_train()
+            else:
+                with self.test_callback_context.before_train_context():
+                    cb.before_train()

     def _after_train(self):
-        self.train.after_train()
-        if self.test:
-            self.test.after_train()
+        for cb in self.cbs:
+            cb.after_train()
         logger.writer.close()

     def trigger_step(self):
-        self.train.trigger_step()
+        for cb in self.cbs:
+            if isinstance(cb.type, TrainCallback):
+                cb.trigger_step()
+        # test callback don't have trigger_step

     def _trigger_epoch(self):
-        self.train.trigger_epoch()
-        if self.test:
-            self.test.trigger_epoch()
+        tm = CallbackTimeLogger()
+        test_sess_restored = False
+        for cb in self.cbs:
+            if isinstance(cb.type, TrainCallback):
+                with tm.timed_callback(type(cb).__name__):
+                    cb.trigger_epoch()
+            else:
+                if not test_sess_restored:
+                    with tm.timed_callback('restore checkpoint'):
+                        self.test_callback_context.restore_checkpoint()
+                    test_sess_restored = True
+                with self.test_callback_context.trigger_epoch_context(), \
+                        tm.timed_callback(type(cb).__name__):
+                    cb.trigger_epoch()
+        tm.log()
         logger.writer.flush()
         logger.stat_holder.finalize()
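Putting the pieces together, a sketch of how a trainer would drive the merged callback group (the loop itself is hypothetical; constructor arguments and loop bounds are illustrative):

max_epoch, steps_per_epoch = 100, 500  # illustrative values
cbs = Callbacks([
    SummaryWriter(),    # moved to the front of the list automatically
    ValidationError(),  # type = TestCallback(): runs in the test session
])
cbs.before_train()
for epoch in range(1, max_epoch + 1):
    for step in range(steps_per_epoch):
        # ... run one training step ...
        cbs.trigger_step()   # dispatched to train callbacks only
    cbs.trigger_epoch()      # restores the checkpoint once, then runs each
                             # test callback inside the test graph/session
cbs.after_train()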
@@ -10,12 +10,12 @@ from tqdm import tqdm
 from ..utils import *
 from ..utils.stat import *
 from ..utils.summary import *
-from .base import PeriodicCallback, Callback
+from .base import PeriodicCallback, Callback, TestCallback

 __all__ = ['ValidationError', 'ValidationCallback']

 class ValidationCallback(PeriodicCallback):
-    running_graph = 'test'
+    type = TestCallback()
     """
     Basic routine for validation callbacks.
     """
@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
         output_vars = self._get_output_vars()
         output_vars.append(self.cost_var)

-        with tqdm(total=self.ds.size()) as pbar:
+        with tqdm(total=self.ds.size(), ascii=True) as pbar:
             for dp in self.ds.get_data():
                 feed = dict(itertools.izip(self.input_vars, dp))
...
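An aside on the `ascii=True` change: it makes tqdm draw the bar with plain ASCII characters ('#' and digits) rather than unicode block characters, which renders safely in log files and non-UTF-8 terminals:

from tqdm import tqdm

for _ in tqdm(range(1000), ascii=True):
    pass  # bar is drawn with '#' characters instead of unicode blocks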
@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import *
 __all__ = ['FullyConnected']

 @layer_register(summary_activation=True)
-def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
+def FullyConnected(x, out_dim,
+                   W_init=None, b_init=None,
+                   nl=tf.nn.relu, use_bias=True):
     x = batch_flatten(x)
     in_dim = x.get_shape().as_list()[1]

@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
         #W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
         W_init = tf.uniform_unit_scaling_initializer()
     if b_init is None:
-        b_init = tf.constant_initializer(0.0)
+        b_init = tf.constant_initializer()
     with tf.device('/cpu:0'):
         W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init)
-        b = tf.get_variable('b', [out_dim], initializer=b_init)
-    return nl(tf.nn.xw_plus_b(x, W, b), name=tf.get_variable_scope().name + '_output')
+        if use_bias:
+            b = tf.get_variable('b', [out_dim], initializer=b_init)
+    prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
+    return nl(prod, name=tf.get_variable_scope().name + '_output')
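A usage sketch for the extended layer (TF 0.x-era API as in the diff; scope names and tensors are illustrative, and this assumes `layer_register` keeps the call signature shown above):

with tf.variable_scope('fc0'):
    l = FullyConnected(image, 512)  # default: ReLU nonlinearity, with bias
with tf.variable_scope('fc1'):
    logits = FullyConnected(l, 10,
                            nl=tf.identity, use_bias=False)  # matmul only, no bias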